import asyncio
import gc
import hashlib
import json
import logging
import os
import uuid
import warnings
from copy import deepcopy
from datetime import datetime, timezone
from typing import Any, Dict, Optional
from pydantic import ValidationError
from durag.configs.base import MemoryConfig, MemoryItem
from durag.configs.enums import MemoryType
from durag.configs.prompts import (
ADDITIVE_EXTRACTION_PROMPT,
AGENT_CONTEXT_SUFFIX,
PROCEDURAL_MEMORY_SYSTEM_PROMPT,
generate_additive_extraction_prompt,
)
from durag.exceptions import ValidationError as DuragValidationError
from durag.memory.base import MemoryBase
from durag.memory.setup import durag_dir, setup_config
from durag.memory.storage import SQLiteManager
from durag.memory.telemetry import DURAG_TELEMETRY, capture_event
from durag.memory.utils import (
extract_json,
parse_messages,
parse_vision_messages,
process_telemetry_filters,
remove_code_blocks,
)
from durag.utils.entity_extraction import extract_entities, extract_entities_batch
from durag.utils.factory import (
EmbedderFactory,
LlmFactory,
RerankerFactory,
VectorStoreFactory,
)
from durag.utils.lemmatization import lemmatize_for_bm25
from durag.utils.scoring import (
ENTITY_BOOST_WEIGHT,
get_bm25_params,
normalize_bm25,
score_and_rank,
)
# Suppress SWIG deprecation warnings globally.
# Both filters match on the warning message text (regex), so only
# DeprecationWarnings mentioning SwigPy / swigvarlink are silenced.
warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*SwigPy.*")
warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*swigvarlink.*")

# Initialize logger early so the module-level helper functions below can log.
logger = logging.getLogger(__name__)

# Fields that hold runtime auth/connection objects and must be preserved.
# These are non-serializable objects (e.g. AWSV4SignerAuth, RequestsHttpConnection)
# needed by clients like OpenSearch — not sensitive strings to redact.
# Checked before the deny lists in _is_sensitive_field, so a name listed here
# is never redacted.
_RUNTIME_FIELDS = frozenset({
    "http_auth",
    "auth",
    "connection_class",
    "ssl_context",
})

# Fields that are known to contain sensitive secrets and must be redacted.
# Compared against the lower-cased, stripped field name (exact match).
_SENSITIVE_FIELDS_EXACT = frozenset({
    "api_key",
    "secret_key",
    "private_key",
    "access_key",
    "password",
    "credentials",
    "credential",
    "secret",
    "token",
    "access_token",
    "refresh_token",
    "auth_token",
    "session_token",
    "client_secret",
    "auth_client_secret",
    "azure_client_secret",
    "service_account_json",
    "aws_session_token",
})

# Suffixes that indicate a field likely holds a secret value.
# Applied with str.endswith, so e.g. "db_password" matches "_password".
_SENSITIVE_SUFFIXES = (
    "_password",
    "_secret",
    "_token",
    "_credential",
    "_credentials",
)

# Entity parameters that must be passed via filters, not top-level kwargs.
ENTITY_PARAMS = frozenset({"user_id", "agent_id", "run_id"})
def _reject_top_level_entity_params(kwargs: Dict[str, Any], method_name: str) -> None:
    """Raise if any entity ID was passed as a top-level keyword argument.

    Args:
        kwargs: The extra keyword arguments received by the calling method.
        method_name: Name of the calling method, used in the error message.

    Raises:
        ValueError: If `kwargs` contains any key listed in ENTITY_PARAMS.
    """
    offending = ENTITY_PARAMS.intersection(kwargs)
    if not offending:
        return
    raise ValueError(
        f"Top-level entity parameters {offending} are not supported in {method_name}(). "
        f"Use filters={{'user_id': '...'}} instead."
    )
def _validate_and_trim_entity_id(value: Optional[str], name: str) -> Optional[str]:
"""
Validates and normalizes an entity ID.
- Trims leading/trailing whitespace
- Rejects empty or whitespace-only strings
- Rejects strings containing internal whitespace
Args:
value: The entity ID value to validate
name: The parameter name (for error messages)
Returns:
The trimmed entity ID, or None if input is None
Raises:
ValueError: If entity ID is invalid
"""
if value is None:
return None
trimmed = value.strip()
if trimmed == "":
raise ValueError(
f"Invalid {name}: cannot be empty or whitespace-only. Provide a valid identifier."
)
if any(c.isspace() for c in trimmed):
raise ValueError(
f"Invalid {name}: cannot contain whitespace. Provide a valid identifier without spaces."
)
return trimmed
def _validate_search_params(threshold: Optional[float] = None, top_k: Optional[int] = None) -> None:
"""
Validates search parameters.
Args:
threshold: Similarity threshold (must be between 0 and 1)
top_k: Number of results to return (must be non-negative integer)
Raises:
ValueError: If threshold or top_k are invalid
"""
if threshold is not None:
if not isinstance(threshold, (int, float)):
raise ValueError("threshold must be a valid number")
if threshold < 0 or threshold > 1:
raise ValueError(
f"Invalid threshold: {threshold}. Must be between 0 and 1 (inclusive)."
)
if top_k is not None:
if not isinstance(top_k, int) or isinstance(top_k, bool):
raise ValueError("top_k must be a valid integer")
if top_k < 0:
raise ValueError(
f"Invalid top_k: {top_k}. Must be a non-negative integer."
)
def _is_sensitive_field(field_name: str) -> bool:
    """Decide whether a config field must be redacted for telemetry safety.

    The name is lower-cased and stripped, then checked in order:
      1. Runtime fields (allowlist) — never redacted, highest priority.
      2. Exact deny list — known secret field names.
      3. Suffix deny list — patterns such as db_password or auth_secret.
    """
    normalized = field_name.lower().strip()
    if normalized in _RUNTIME_FIELDS:
        return False
    # str.endswith accepts a tuple and returns True if any suffix matches.
    return normalized in _SENSITIVE_FIELDS_EXACT or normalized.endswith(_SENSITIVE_SUFFIXES)
def _safe_deepcopy_config(config):
"""Safely deepcopy config, falling back to dict-based cloning for non-serializable objects."""
try:
return deepcopy(config)
except Exception as e:
logger.debug(f"Deepcopy failed, using dict-based cloning: {e}")
config_class = type(config)
if hasattr(config, "model_dump"):
try:
clone_dict = config.model_dump()
except Exception:
clone_dict = dict(config.__dict__)
else:
clone_dict = dict(config.__dict__)
# Restore runtime fields, redact sensitive ones
for field_name in list(clone_dict.keys()):
if field_name in _RUNTIME_FIELDS and hasattr(config, field_name):
clone_dict[field_name] = getattr(config, field_name)
elif _is_sensitive_field(field_name):
clone_dict[field_name] = None
try:
return config_class(**clone_dict)
except Exception:
logger.debug("Config reconstruction failed, returning shallow dict clone")
return type("Config", (), clone_dict)()
def _normalize_iso_timestamp_to_utc(timestamp: Optional[str]) -> Optional[str]:
"""Normalize timezone-aware ISO timestamps to UTC without rewriting naive values."""
if not timestamp:
return timestamp
try:
parsed = datetime.fromisoformat(timestamp)
except ValueError:
return timestamp
if parsed.tzinfo is None:
return timestamp
return parsed.astimezone(timezone.utc).isoformat()
def _build_filters_and_metadata(
    *,  # Enforce keyword-only arguments
    user_id: Optional[str] = None,
    agent_id: Optional[str] = None,
    run_id: Optional[str] = None,
    actor_id: Optional[str] = None,  # For query-time filtering
    input_metadata: Optional[Dict[str, Any]] = None,
    input_filters: Optional[Dict[str, Any]] = None,
) -> tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Build the storage metadata template and the query filters for one operation.

    Both results start as deep copies of the caller's dicts (or empty dicts),
    so the inputs are never mutated. Every session identifier that is provided
    (`user_id`, `agent_id`, `run_id`) is validated, trimmed, and written into
    both dicts. An actor identifier is written into the query filters only:
    the explicit `actor_id` argument wins, otherwise an `actor_id` already
    present in `input_filters` is kept.

    Args:
        user_id (Optional[str]): User identifier, for session scoping.
        agent_id (Optional[str]): Agent identifier, for session scoping.
        run_id (Optional[str]): Run identifier, for session scoping.
        actor_id (Optional[str]): Explicit actor identifier for query filtering.
        input_metadata (Optional[Dict[str, Any]]): Base metadata to augment with
            the session identifiers. Defaults to an empty dict.
        input_filters (Optional[Dict[str, Any]]): Base filters to augment with the
            session and actor identifiers. Defaults to an empty dict.

    Returns:
        tuple[Dict[str, Any], Dict[str, Any]]:
            - base_metadata_template: metadata for storing memories.
            - effective_query_filters: filters for querying memories.

    Raises:
        ValueError: If a provided session identifier is empty or contains whitespace.
        DuragValidationError: If none of the session identifiers is provided.
    """
    metadata_template = deepcopy(input_metadata) if input_metadata else {}
    query_filters = deepcopy(input_filters) if input_filters else {}

    # Validate all three IDs up front so a malformed one fails before any
    # of them is written into the result dicts.
    session_ids = {
        "user_id": _validate_and_trim_entity_id(user_id, "user_id"),
        "agent_id": _validate_and_trim_entity_id(agent_id, "agent_id"),
        "run_id": _validate_and_trim_entity_id(run_id, "run_id"),
    }

    provided = [key for key, value in session_ids.items() if value]
    for key in provided:
        metadata_template[key] = session_ids[key]
        query_filters[key] = session_ids[key]

    if not provided:
        raise DuragValidationError(
            message="At least one of 'user_id', 'agent_id', or 'run_id' must be provided.",
            error_code="VALIDATION_001",
            details={"provided_ids": dict(session_ids)},
            suggestion="Please provide at least one identifier to scope the memory operation."
        )

    # Optional actor filter: explicit argument first, then the incoming filters.
    actor = actor_id or query_filters.get("actor_id")
    if actor:
        query_filters["actor_id"] = actor

    return metadata_template, query_filters
def _build_session_scope(filters):
"""Build deterministic session scope string from entity IDs."""
parts = []
for key in sorted(["user_id", "agent_id", "run_id"]):
val = filters.get(key)
if val:
parts.append(f"{key}={val}")
return "&".join(parts)
# Runs the package setup hook as an import-time side effect.
# NOTE(review): setup_config is imported from durag.memory.setup; what it
# creates or writes is not visible in this file — confirm before relying on it.
setup_config()
# Re-binds `logger`; logging.getLogger(__name__) was already assigned to this
# name earlier in the module, so this is redundant but harmless.
logger = logging.getLogger(__name__)
[docs]
class Memory(MemoryBase):
def __init__(self, config: Optional[MemoryConfig] = None):
    """Wire up the embedder, vector store, LLM, history DB and optional reranker.

    Args:
        config: Memory configuration. When omitted, a fresh `MemoryConfig()`
            is created for this instance.
    """
    # A `MemoryConfig()` default in the signature is evaluated once, when the
    # function is defined, and would then be shared by every instance created
    # without an explicit config. Build it per instance instead.
    self.config = config if config is not None else MemoryConfig()

    self.embedding_model = EmbedderFactory.create(
        self.config.embedder.provider,
        self.config.embedder.config,
        self.config.vector_store.config,
    )
    self.vector_store = VectorStoreFactory.create(
        self.config.vector_store.provider, self.config.vector_store.config
    )
    self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
    self.db = SQLiteManager(self.config.history_db_path)
    self.collection_name = self.config.vector_store.config.collection_name
    self.api_version = self.config.version
    self.custom_instructions = self.config.custom_instructions

    # Initialize reranker if configured
    self.reranker = None
    if self.config.reranker:
        self.reranker = RerankerFactory.create(
            self.config.reranker.provider,
            self.config.reranker.config
        )

    # Entity store is initialized lazily on first use
    self._entity_store = None

    if DURAG_TELEMETRY:
        # Create telemetry config manually to avoid deepcopy issues with thread locks
        telemetry_config_dict = {}
        if hasattr(self.config.vector_store.config, 'model_dump'):
            # For pydantic models
            telemetry_config_dict = self.config.vector_store.config.model_dump()
        else:
            # For other objects, manually copy common attributes
            for attr in ['host', 'port', 'path', 'api_key', 'index_name', 'dimension', 'metric']:
                if hasattr(self.config.vector_store.config, attr):
                    telemetry_config_dict[attr] = getattr(self.config.vector_store.config, attr)

        # Override collection name for telemetry
        telemetry_config_dict['collection_name'] = "duragmigrations"

        # Set path for file-based vector stores
        # NOTE(review): this clone is replaced a few lines below by a config
        # rebuilt from telemetry_config_dict; kept to preserve the original
        # call sequence — confirm whether it can be dropped.
        telemetry_config = _safe_deepcopy_config(self.config.vector_store.config)
        if self.config.vector_store.provider in ["faiss", "qdrant"]:
            provider_path = f"migrations_{self.config.vector_store.provider}"
            telemetry_config_dict['path'] = os.path.join(durag_dir, provider_path)
            os.makedirs(telemetry_config_dict['path'], exist_ok=True)

        # Create the config object using the same class as the original
        telemetry_config = self.config.vector_store.config.__class__(**telemetry_config_dict)
        self._telemetry_vector_store = VectorStoreFactory.create(
            self.config.vector_store.provider, telemetry_config
        )
    capture_event("durag.init", self, {"sync_type": "sync"})
@property
def entity_store(self):
"""Lazily initialize entity store on first use."""
if self._entity_store is None:
entity_config = _safe_deepcopy_config(self.config.vector_store.config)
entity_collection = f"{self.collection_name}_entities"
# Set collection name on the cloned config
if hasattr(entity_config, 'collection_name'):
entity_config.collection_name = entity_collection
elif isinstance(entity_config, dict):
entity_config['collection_name'] = entity_collection
# For Qdrant, share the existing client to avoid RocksDB lock contention
# when using embedded mode (path=...). QdrantConfig.client takes precedence
# over host/port/path.
if self.config.vector_store.provider == "qdrant" and hasattr(self.vector_store, "client"):
if hasattr(entity_config, "client"):
entity_config.client = self.vector_store.client
elif isinstance(entity_config, dict):
entity_config["client"] = self.vector_store.client
self._entity_store = VectorStoreFactory.create(
self.config.vector_store.provider, entity_config
)
return self._entity_store
def _upsert_entity(self, entity_text, entity_type, memory_id, filters):
"""Upsert an entity into the entity store, linking it to a memory."""
try:
entity_embedding = self.embedding_model.embed(entity_text, "add")
search_filters = {k: v for k, v in filters.items() if k in ("user_id", "agent_id", "run_id") and v}
existing = self.entity_store.search(
query=entity_text,
vectors=entity_embedding,
top_k=1,
filters=search_filters,
)
if existing and existing[0].score >= 0.95:
# Update existing entity's linked_memory_ids
match = existing[0]
payload = match.payload or {}
linked_ids = payload.get("linked_memory_ids", [])
if memory_id not in linked_ids:
linked_ids.append(memory_id)
payload["linked_memory_ids"] = linked_ids
self.entity_store.update(
vector_id=match.id,
vector=None,
payload=payload,
)
else:
# Create new entity
entity_id = str(uuid.uuid4())
entity_payload = {
"data": entity_text,
"entity_type": entity_type,
"linked_memory_ids": [memory_id],
**{k: v for k, v in search_filters.items()},
}
self.entity_store.insert(
vectors=[entity_embedding],
ids=[entity_id],
payloads=[entity_payload],
)
except Exception as e:
logger.warning(f"Entity upsert failed for '{entity_text}': {e}")
def _remove_memory_from_entity_store(self, memory_id, filters):
"""Strip `memory_id` from every entity record scoped to `filters`.
For each entity whose `linked_memory_ids` contains `memory_id`:
- remove the id; if the list becomes empty, delete the entity record.
- otherwise re-embed the entity text and update the payload
(the vector store's update() requires a vector).
No-op if the entity store has never been initialized in this process.
Errors on individual entities are swallowed at debug level; outer
failures are swallowed at warning level so the primary delete/update
path is never broken by entity cleanup.
"""
if self._entity_store is None:
return
search_filters = {k: v for k, v in filters.items() if k in ("user_id", "agent_id", "run_id") and v}
try:
listed = self.entity_store.list(filters=search_filters, top_k=10000)
rows = listed[0] if isinstance(listed, (list, tuple)) and listed and isinstance(listed[0], list) else listed
for row in rows or []:
try:
payload = getattr(row, "payload", None) or {}
linked = payload.get("linked_memory_ids", [])
if not isinstance(linked, list) or memory_id not in linked:
continue
remaining = [mid for mid in linked if mid != memory_id]
if not remaining:
try:
self.entity_store.delete(vector_id=row.id)
except Exception as e:
logger.debug(f"Entity delete failed for id={row.id}: {e}")
else:
entity_text = payload.get("data")
if not isinstance(entity_text, str) or not entity_text:
logger.debug(f"Entity id={row.id} missing 'data'; skipping update during cleanup")
continue
try:
vec = self.embedding_model.embed(entity_text, "update")
except Exception as e:
logger.debug(f"Entity re-embed failed for '{entity_text}': {e}")
continue
new_payload = {**payload, "linked_memory_ids": remaining}
try:
self.entity_store.update(
vector_id=row.id,
vector=vec,
payload=new_payload,
)
except Exception as e:
logger.debug(f"Entity update failed for id={row.id}: {e}")
except Exception as e:
logger.debug(f"Entity cleanup error: {e}")
except Exception as e:
logger.warning(f"Entity store cleanup failed for memory_id={memory_id}: {e}")
def _link_entities_for_memory(self, memory_id, text, filters):
    """Extract entities from `text` and link each one to `memory_id`.

    Single-memory counterpart of the batch entity linking done in add():
    every distinct entity (compared case-insensitively after stripping) is
    passed to `_upsert_entity` within the `filters` scope. Failures are
    logged and never propagated.
    """
    try:
        found = extract_entities(text)
        if not found:
            return

        handled = set()
        for entity_type, entity_text in found:
            normalized = entity_text.strip().lower()
            if normalized and normalized not in handled:
                handled.add(normalized)
                try:
                    self._upsert_entity(entity_text, entity_type, memory_id, filters)
                except Exception as e:
                    logger.debug(f"Entity link failed for '{entity_text}': {e}")
    except Exception as e:
        logger.warning(f"Entity linking failed for memory_id={memory_id}: {e}")
[docs]
@classmethod
def from_config(cls, config_dict: Dict[str, Any]):
    """Build an instance from a plain configuration dictionary.

    Args:
        config_dict: Keyword arguments for `MemoryConfig`.

    Returns:
        A new instance of `cls` configured from `config_dict`.

    Raises:
        ValidationError: If `MemoryConfig` rejects the configuration; the
            error is logged and re-raised.
    """
    try:
        # Build the config from the processed dict. Previously the result of
        # _process_config was discarded and the raw input was used instead.
        processed = cls._process_config(config_dict)
        config = MemoryConfig(**processed)
    except ValidationError as e:
        logger.error(f"Configuration validation error: {e}")
        raise
    return cls(config)
@staticmethod
def _process_config(config_dict: Dict[str, Any]) -> Dict[str, Any]:
try:
return config_dict
except ValidationError as e:
logger.error(f"Configuration validation error: {e}")
raise
def _should_use_agent_memory_extraction(self, messages, metadata):
"""Determine whether to use agent memory extraction based on the logic:
- If agent_id is present and messages contain assistant role -> True
- Otherwise -> False
Args:
messages: List of message dictionaries
metadata: Metadata containing user_id, agent_id, etc.
Returns:
bool: True if should use agent memory extraction, False for user memory extraction
"""
# Check if agent_id is present in metadata
has_agent_id = metadata.get("agent_id") is not None
# Check if there are assistant role messages
has_assistant_messages = any(msg.get("role") == "assistant" for msg in messages)
# Use agent memory extraction if agent_id is present and there are assistant messages
return has_agent_id and has_assistant_messages
[docs]
def add(
    self,
    messages,
    *,
    user_id: Optional[str] = None,
    agent_id: Optional[str] = None,
    run_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    infer: bool = True,
    memory_type: Optional[str] = None,
    prompt: Optional[str] = None,
):
    """
    Create new memories from one or more messages.

    The memories are scoped by the session ids given (`user_id`, `agent_id`,
    `run_id`); at least one of them is required.

    Args:
        messages (str or dict or List[Dict[str, str]]): A plain string (treated
            as a single user message), a single message dict, or a list of
            message dicts such as
            `[{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi"}]`.
        user_id (str, optional): ID of the user creating the memory. Defaults to None.
        agent_id (str, optional): ID of the agent creating the memory. Defaults to None.
        run_id (str, optional): ID of the run creating the memory. Defaults to None.
        metadata (dict, optional): Metadata to store with the memory. Defaults to None.
        infer (bool, optional): When True (default) an LLM extracts facts from
            `messages` before storing. When False the messages are stored as
            raw memories.
        memory_type (str, optional): Only `MemoryType.PROCEDURAL.value` is
            accepted. Combined with an `agent_id` it creates a procedural
            memory; any other non-None value is rejected.
        prompt (str, optional): Prompt to use for the memory creation. Defaults to None.

    Returns:
        dict: For the regular path, `{"results": [...]}` with the affected
            memory items, e.g. `{"results": [{"id": "...", "memory": "...", "event": "ADD"}]}`.
            For procedural memories, whatever `_create_procedural_memory` returns.

    Raises:
        DuragValidationError: If no session id is given, `memory_type` is
            invalid, or `messages` has an unsupported type.
        ValueError: If a session id is empty or contains whitespace.
    """
    processed_metadata, effective_filters = _build_filters_and_metadata(
        user_id=user_id,
        agent_id=agent_id,
        run_id=run_id,
        input_metadata=metadata,
    )

    procedural = MemoryType.PROCEDURAL.value
    if memory_type is not None and memory_type != procedural:
        raise DuragValidationError(
            message=f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories.",
            error_code="VALIDATION_002",
            details={"provided_type": memory_type, "valid_type": MemoryType.PROCEDURAL.value},
            suggestion=f"Use '{MemoryType.PROCEDURAL.value}' to create procedural memories."
        )

    # Normalize every accepted input shape to a list of message dicts.
    if isinstance(messages, str):
        messages = [{"role": "user", "content": messages}]
    elif isinstance(messages, dict):
        messages = [messages]
    elif not isinstance(messages, list):
        raise DuragValidationError(
            message="messages must be str, dict, or list[dict]",
            error_code="VALIDATION_003",
            details={"provided_type": type(messages).__name__, "valid_types": ["str", "dict", "list[dict]"]},
            suggestion="Convert your input to a string, dictionary, or list of dictionaries."
        )

    if agent_id is not None and memory_type == MemoryType.PROCEDURAL.value:
        return self._create_procedural_memory(messages, metadata=processed_metadata, prompt=prompt)

    llm_settings = self.config.llm.config
    if llm_settings.get("enable_vision"):
        messages = parse_vision_messages(messages, self.llm, self.config.llm.config.get("vision_details"))
    else:
        messages = parse_vision_messages(messages)

    stored = self._add_to_vector_store(messages, processed_metadata, effective_filters, infer, prompt=prompt)
    return {"results": stored}
def _add_to_vector_store(self, messages, metadata, filters, infer, prompt=None):
    """Persist memories for `messages` and return the list of ADD results.

    With `infer=False` every valid non-system message is embedded and stored
    verbatim. Otherwise the phased pipeline runs: gather recent context,
    retrieve similar existing memories, ask the LLM for new memories, embed
    them, drop hash duplicates, batch-insert, record history, link entities,
    and finally save the messages.

    Args:
        messages: List of message dicts (`role`, `content`, optional `name`).
        metadata: Base metadata copied into every stored memory.
        filters: Query filters; only user_id/agent_id/run_id are used for scoping.
        infer: Whether to run LLM extraction.
        prompt: Optional custom instructions overriding `self.custom_instructions`.

    Returns:
        list[dict]: One `{"id", "memory", "event": "ADD", ...}` entry per
        stored memory; empty when nothing was extracted or the LLM call failed.
    """
    if not infer:
        returned_memories = []
        for message_dict in messages:
            if (
                not isinstance(message_dict, dict)
                or message_dict.get("role") is None
                or message_dict.get("content") is None
            ):
                logger.warning(f"Skipping invalid message format: {message_dict}")
                continue
            if message_dict["role"] == "system":
                continue
            per_msg_meta = deepcopy(metadata)
            per_msg_meta["role"] = message_dict["role"]
            actor_name = message_dict.get("name")
            if actor_name:
                per_msg_meta["actor_id"] = actor_name
            msg_content = message_dict["content"]
            msg_embeddings = self.embedding_model.embed(msg_content, "add")
            mem_id = self._create_memory(msg_content, {msg_content: msg_embeddings}, per_msg_meta)
            returned_memories.append(
                {
                    "id": mem_id,
                    "memory": msg_content,
                    "event": "ADD",
                    "actor_id": actor_name if actor_name else None,
                    "role": message_dict["role"],
                }
            )
        return returned_memories

    # === V3 PHASED BATCH PIPELINE ===
    # Phase 0: Context gathering
    session_scope = _build_session_scope(filters)
    last_messages = self.db.get_last_messages(session_scope, limit=10)
    parsed_messages = parse_messages(messages)

    # Phase 1: Existing memory retrieval
    search_filters = {k: v for k, v in filters.items() if k in ("user_id", "agent_id", "run_id") and v}
    query_embedding = self.embedding_model.embed(parsed_messages, "search")
    existing_results = self.vector_store.search(
        query=parsed_messages,
        vectors=query_embedding,
        top_k=10,
        filters=search_filters,
    )

    # Map UUIDs to integers (anti-hallucination)
    existing_memories = []
    uuid_mapping = {}
    for idx, mem in enumerate(existing_results):
        uuid_mapping[str(idx)] = mem.id
        existing_memories.append({"id": str(idx), "text": mem.payload.get("data", "")})

    # Phase 2: LLM extraction (single call)
    is_agent_scoped = bool(filters.get("agent_id")) and not filters.get("user_id")
    system_prompt = ADDITIVE_EXTRACTION_PROMPT
    if is_agent_scoped:
        system_prompt += AGENT_CONTEXT_SUFFIX
    custom_instr = prompt or self.custom_instructions
    user_prompt = generate_additive_extraction_prompt(
        existing_memories=existing_memories,
        new_messages=parsed_messages,
        last_k_messages=last_messages,
        custom_instructions=custom_instr,
    )
    try:
        response = self.llm.generate_response(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            response_format={"type": "json_object"},
        )
    except Exception as e:
        logger.error(f"LLM extraction failed: {e}")
        return []

    # Parse response
    try:
        response = remove_code_blocks(response)
        if not response or not response.strip():
            extracted_memories = []
        else:
            try:
                extracted_memories = json.loads(response, strict=False).get("memory", [])
            except json.JSONDecodeError:
                extracted_json = extract_json(response)
                extracted_memories = json.loads(extracted_json, strict=False).get("memory", [])
    except Exception as e:
        logger.error(f"Error parsing extraction response: {e}")
        extracted_memories = []

    # The LLM output is untrusted: "memory" may not be a list, and entries may
    # not be dicts with a string "text". Such values used to crash later on
    # `.get(...)` / `.encode()`; drop them here instead. Entries without text
    # were already skipped further down, so removing them changes nothing else.
    if not isinstance(extracted_memories, list):
        logger.warning("Ignoring extraction result: 'memory' is not a list")
        extracted_memories = []
    extracted_memories = [
        m for m in extracted_memories if isinstance(m, dict) and isinstance(m.get("text"), str)
    ]

    if not extracted_memories:
        # Save messages even if nothing extracted
        self.db.save_messages(messages, session_scope)
        return []

    # Phase 3: Batch embed all extracted memory texts
    mem_texts = [m.get("text", "") for m in extracted_memories if m.get("text")]
    try:
        mem_embeddings_list = self.embedding_model.embed_batch(mem_texts, "add")
        embed_map = dict(zip(mem_texts, mem_embeddings_list))
    except Exception:
        # Fallback: embed individually
        embed_map = {}
        for text in mem_texts:
            try:
                embed_map[text] = self.embedding_model.embed(text, "add")
            except Exception as e:
                logger.warning(f"Failed to embed memory text: {e}")

    # Phase 4: Per-memory CPU processing + Phase 5: Hash dedup
    # Build set of existing hashes for dedup
    existing_hashes = set()
    for mem in existing_results:
        h = mem.payload.get("hash") if hasattr(mem, "payload") and mem.payload else None
        if h:
            existing_hashes.add(h)

    records = []  # (memory_id, text, embedding, payload)
    seen_hashes = set()  # dedup within the current batch
    for mem in extracted_memories:
        text = mem.get("text")
        if not text or text not in embed_map:
            continue
        # MD5 is used only as a content fingerprint for dedup.
        mem_hash = hashlib.md5(text.encode()).hexdigest()
        if mem_hash in existing_hashes or mem_hash in seen_hashes:
            logger.debug(f"Skipping duplicate memory (hash match): {text[:50]}")
            continue
        seen_hashes.add(mem_hash)
        text_lemmatized = lemmatize_for_bm25(text)
        memory_id = str(uuid.uuid4())
        mem_metadata = deepcopy(metadata)
        mem_metadata["data"] = text
        mem_metadata["text_lemmatized"] = text_lemmatized
        mem_metadata["hash"] = mem_hash
        if "created_at" not in mem_metadata:
            mem_metadata["created_at"] = datetime.now(timezone.utc).isoformat()
        mem_metadata["updated_at"] = mem_metadata["created_at"]
        if mem.get("attributed_to"):
            mem_metadata["attributed_to"] = mem["attributed_to"]
        records.append((memory_id, text, embed_map[text], mem_metadata))

    if not records:
        self.db.save_messages(messages, session_scope)
        return []

    # Phase 6: Batch persist
    all_vectors = [r[2] for r in records]
    all_ids = [r[0] for r in records]
    all_payloads = [r[3] for r in records]
    try:
        self.vector_store.insert(
            vectors=all_vectors,
            ids=all_ids,
            payloads=all_payloads,
        )
    except Exception:
        # Fallback: insert one by one
        for mid, vec, pay in zip(all_ids, all_vectors, all_payloads):
            try:
                self.vector_store.insert(vectors=[vec], ids=[mid], payloads=[pay])
            except Exception as e:
                logger.error(f"Failed to insert memory {mid}: {e}")

    # Batch history
    history_records = [
        {
            "memory_id": r[0],
            "old_memory": None,
            "new_memory": r[1],
            "event": "ADD",
            "created_at": r[3].get("created_at"),
            "is_deleted": 0,
        }
        for r in records
    ]
    try:
        self.db.batch_add_history(history_records)
    except Exception:
        # Fallback: add one by one
        for hr in history_records:
            try:
                self.db.add_history(hr["memory_id"], None, hr["new_memory"], "ADD", created_at=hr.get("created_at"))
            except Exception as e:
                logger.error(f"Failed to add history for {hr['memory_id']}: {e}")

    # Phase 7: Batch entity linking (best-effort; failures never abort the add)
    try:
        all_texts = [r[1] for r in records]
        all_entities = extract_entities_batch(all_texts)

        # 7a: Global dedup — collect unique entities across all memories
        global_entities = {}  # normalized_key -> [entity_type, entity_text, set of memory_ids]
        for idx, record in enumerate(records):
            memory_id = record[0]
            entities = all_entities[idx] if idx < len(all_entities) else []
            for entity_type, entity_text in entities:
                key = entity_text.strip().lower()
                if key in global_entities:
                    global_entities[key][2].add(memory_id)
                else:
                    global_entities[key] = [entity_type, entity_text, {memory_id}]

        if global_entities:
            ordered_keys = list(global_entities.keys())
            entity_texts = [global_entities[k][1] for k in ordered_keys]

            # 7b: Single batch embed for all unique entities
            try:
                entity_embeddings = self.embedding_model.embed_batch(entity_texts, "add")
            except Exception:
                # Fallback: embed individually, use None for failures
                entity_embeddings = []
                for t in entity_texts:
                    try:
                        entity_embeddings.append(self.embedding_model.embed(t, "add"))
                    except Exception:
                        entity_embeddings.append(None)

            # Filter out entities with failed embeddings
            valid = [(i, k) for i, k in enumerate(ordered_keys) if entity_embeddings[i] is not None]
            if valid:
                valid_indices, valid_keys = zip(*valid)
                valid_vectors = [entity_embeddings[i] for i in valid_indices]

                # 7c: Batch search for existing entities
                valid_texts = [global_entities[k][1] for k in valid_keys]
                existing_matches = self.entity_store.search_batch(
                    queries=valid_texts,
                    vectors_list=valid_vectors,
                    top_k=1,
                    filters=search_filters,
                )

                # 7d: Separate into inserts vs updates
                to_insert_vectors, to_insert_ids, to_insert_payloads = [], [], []
                for j, key in enumerate(valid_keys):
                    entity_type, entity_text, memory_ids = global_entities[key]
                    matches = existing_matches[j] if j < len(existing_matches) else []
                    if matches and matches[0].score >= 0.95:
                        # Update existing entity
                        match = matches[0]
                        payload = match.payload or {}
                        linked = set(payload.get("linked_memory_ids", []))
                        linked |= memory_ids
                        payload["linked_memory_ids"] = sorted(linked)
                        try:
                            self.entity_store.update(
                                vector_id=match.id,
                                vector=None,
                                payload=payload,
                            )
                        except Exception as e:
                            logger.debug(f"Entity update failed for '{entity_text}': {e}")
                    else:
                        # New entity — collect for batch insert
                        to_insert_vectors.append(valid_vectors[j])
                        to_insert_ids.append(str(uuid.uuid4()))
                        to_insert_payloads.append({
                            "data": entity_text,
                            "entity_type": entity_type,
                            "linked_memory_ids": sorted(memory_ids),
                            **search_filters,
                        })

                # 7e: Single batch insert for all new entities
                if to_insert_vectors:
                    try:
                        self.entity_store.insert(
                            vectors=to_insert_vectors,
                            ids=to_insert_ids,
                            payloads=to_insert_payloads,
                        )
                    except Exception as e:
                        logger.warning(f"Batch entity insert failed: {e}")
    except Exception as e:
        logger.warning(f"Batch entity linking failed: {e}")

    # Phase 8: Save messages + return
    self.db.save_messages(messages, session_scope)
    returned_memories = [
        {"id": r[0], "memory": r[1], "event": "ADD"}
        for r in records
    ]
    keys, encoded_ids = process_telemetry_filters(filters)
    capture_event(
        "durag.add",
        self,
        {"version": self.api_version, "keys": keys, "encoded_ids": encoded_ids, "sync_type": "sync"},
    )
    return returned_memories
[docs]
def get(self, memory_id):
    """
    Retrieve a memory by ID.

    Args:
        memory_id (str): ID of the memory to retrieve.

    Returns:
        dict or None: The memory as a dict (core fields, promoted session
        fields when present, and any remaining payload keys under
        "metadata"), or None when the vector store has no such record.
    """
    capture_event("durag.get", self, {"memory_id": memory_id, "sync_type": "sync"})
    memory = self.vector_store.get(vector_id=memory_id)
    if not memory:
        return None

    # A record can come back without a payload; treat that as empty instead
    # of failing on `.get`, matching the `payload or {}` guard used elsewhere
    # in this class.
    payload = memory.payload or {}

    promoted_payload_keys = [
        "user_id",
        "agent_id",
        "run_id",
        "actor_id",
        "role",
    ]
    # Keys already surfaced at the top level (or internal) are kept out of "metadata".
    core_and_promoted_keys = {
        "data", "hash", "created_at", "updated_at", "id",
        "text_lemmatized", "attributed_to", *promoted_payload_keys,
    }

    result_item = MemoryItem(
        id=memory.id,
        memory=payload.get("data", ""),
        hash=payload.get("hash"),
        created_at=payload.get("created_at"),
        updated_at=payload.get("updated_at"),
    ).model_dump()

    for key in promoted_payload_keys:
        if key in payload:
            result_item[key] = payload[key]

    additional_metadata = {k: v for k, v in payload.items() if k not in core_and_promoted_keys}
    if additional_metadata:
        result_item["metadata"] = additional_metadata
    return result_item
[docs]
def get_all(
    self,
    *,
    filters: Optional[Dict[str, Any]] = None,
    top_k: int = 20,
    **kwargs,
):
    """
    List memories that match the given filters.

    Args:
        filters (dict): Filter dict containing entity IDs and optional metadata filters.
            Must contain at least one of: user_id, agent_id, run_id.
            Example: filters={"user_id": "u1", "agent_id": "a1"}
        top_k (int, optional): The maximum number of memories to return. Defaults to 20.

    Returns:
        dict: A dictionary containing a list of memories under the "results" key.
            Example for v1.1+: `{"results": [{"id": "...", "memory": "...", ...}]}`

    Raises:
        ValueError: If filters doesn't contain at least one of user_id, agent_id, run_id,
            or if top_k is invalid.
    """
    # Entity IDs are only accepted inside `filters`, never as top-level kwargs.
    _reject_top_level_entity_params(kwargs, "get_all")
    _validate_search_params(top_k=top_k)

    # Work on a copy so the caller's dict is not modified by the trimming below.
    scoped_filters = dict(filters) if filters else {}
    entity_keys = ("user_id", "agent_id", "run_id")
    for entity_key in entity_keys:
        if entity_key in scoped_filters:
            scoped_filters[entity_key] = _validate_and_trim_entity_id(
                scoped_filters[entity_key], entity_key
            )

    if all(entity_key not in scoped_filters for entity_key in entity_keys):
        raise ValueError(
            "filters must contain at least one of: user_id, agent_id, run_id. "
            "Example: filters={'user_id': 'u1'}"
        )

    limit = top_k
    keys, encoded_ids = process_telemetry_filters(scoped_filters)
    capture_event(
        "durag.get_all",
        self,
        {"limit": limit, "keys": keys, "encoded_ids": encoded_ids, "sync_type": "sync"},
    )

    return {"results": self._get_all_from_vector_store(scoped_filters, limit)}
def _get_all_from_vector_store(self, filters, limit):
    """
    List memories from the vector store and format them as plain dicts.

    Args:
        filters (dict): Filters forwarded to ``vector_store.list``.
        limit (int): Forwarded to ``vector_store.list`` as ``top_k``.

    Returns:
        list[dict]: One dumped ``MemoryItem`` (without "score") per record,
        with promoted payload fields at top level and the remaining payload
        fields under "metadata".
    """
    listing = self.vector_store.list(filters=filters, top_k=limit)

    # The listing is either a flat sequence of records or a sequence whose
    # first element is the record container; peek at the first element to
    # tell the two apart.
    records = listing
    if isinstance(listing, (tuple, list)) and listing and isinstance(listing[0], (list, tuple)):
        records = listing[0]

    promoted_fields = ("user_id", "agent_id", "run_id", "actor_id", "role")
    reserved_fields = {
        "data",
        "hash",
        "created_at",
        "updated_at",
        "id",
        "text_lemmatized",
        "attributed_to",
        *promoted_fields,
    }

    formatted = []
    for record in records:
        item = MemoryItem(
            id=record.id,
            memory=record.payload.get("data", ""),
            hash=record.payload.get("hash"),
            created_at=record.payload.get("created_at"),
            updated_at=record.payload.get("updated_at"),
        ).model_dump(exclude={"score"})

        for field in promoted_fields:
            if field in record.payload:
                item[field] = record.payload[field]

        extras = {
            name: value
            for name, value in record.payload.items()
            if name not in reserved_fields
        }
        if extras:
            item["metadata"] = extras

        formatted.append(item)
    return formatted
[docs]
def search(
    self,
    query: str,
    *,
    top_k: int = 20,
    filters: Optional[Dict[str, Any]] = None,
    threshold: float = 0.1,
    rerank: bool = False,
    **kwargs,
):
    """
    Search stored memories for a query.

    Args:
        query (str): Query to search for.
        top_k (int, optional): Maximum number of results to return. Defaults to 20.
        filters (dict): Filter dict containing entity IDs and optional metadata filters.
            Must contain at least one of: user_id, agent_id, run_id.
            Example: filters={"user_id": "u1", "agent_id": "a1"}

            Enhanced metadata filtering with operators:
            - {"key": "value"} - exact match
            - {"key": {"eq": "value"}} - equals
            - {"key": {"ne": "value"}} - not equals
            - {"key": {"in": ["val1", "val2"]}} - in list
            - {"key": {"nin": ["val1", "val2"]}} - not in list
            - {"key": {"gt": 10}} - greater than
            - {"key": {"gte": 10}} - greater than or equal
            - {"key": {"lt": 10}} - less than
            - {"key": {"lte": 10}} - less than or equal
            - {"key": {"contains": "text"}} - contains text
            - {"key": {"icontains": "text"}} - case-insensitive contains
            - {"key": "*"} - wildcard match (any value)
            - {"AND": [filter1, filter2]} - logical AND
            - {"OR": [filter1, filter2]} - logical OR
            - {"NOT": [filter1]} - logical NOT
        threshold (float, optional): Minimum score for a memory to be included. Defaults to 0.1.
        rerank (bool, optional): Whether to rerank results. Defaults to False.

    Returns:
        dict: A dictionary containing the search results under a "results" key.
            Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "score": 0.8, ...}]}`

    Raises:
        ValueError: If filters doesn't contain at least one of user_id, agent_id, run_id,
            or if threshold/top_k values are invalid.
    """
    # Entity IDs are only accepted inside `filters`, never as top-level kwargs.
    _reject_top_level_entity_params(kwargs, "search")
    _validate_search_params(threshold=threshold, top_k=top_k)

    # Shallow copy so trimming and operator rewriting leave the caller's dict alone.
    scoped_filters = filters.copy() if filters else {}
    entity_keys = ("user_id", "agent_id", "run_id")
    for entity_key in entity_keys:
        if entity_key in scoped_filters:
            scoped_filters[entity_key] = _validate_and_trim_entity_id(
                scoped_filters[entity_key], entity_key
            )
    if all(entity_key not in scoped_filters for entity_key in entity_keys):
        raise ValueError(
            "filters must contain at least one of: user_id, agent_id, run_id. "
            "Example: filters={'user_id': 'u1'}"
        )

    limit = top_k

    if self._has_advanced_operators(scoped_filters):
        translated = self._process_metadata_filters(scoped_filters)
        # Drop the raw logical keys and raw operator dicts; their translated
        # equivalents are merged back in below.
        for logical_key in ("AND", "OR", "NOT"):
            scoped_filters.pop(logical_key, None)
        keep_as_is = {"AND", "OR", "NOT", "user_id", "agent_id", "run_id"}
        stale = [
            name
            for name, spec in scoped_filters.items()
            if name not in keep_as_is and isinstance(spec, dict)
        ]
        for name in stale:
            scoped_filters.pop(name, None)
        scoped_filters.update(translated)

    keys, encoded_ids = process_telemetry_filters(scoped_filters)
    capture_event(
        "durag.search",
        self,
        {
            "limit": limit,
            "version": self.api_version,
            "keys": keys,
            "encoded_ids": encoded_ids,
            "sync_type": "sync",
            "threshold": threshold,
            "advanced_filters": bool(filters and self._has_advanced_operators(filters)),
        },
    )

    memories = self._search_vector_store(query, scoped_filters, limit, threshold)

    # Reranking is best-effort: on failure the un-reranked results are returned.
    if rerank and self.reranker and memories:
        try:
            memories = self.reranker.rerank(query, memories, limit)
        except Exception as e:
            logger.warning(f"Reranking failed, using original results: {e}")

    return {"results": memories}
def _process_metadata_filters(self, metadata_filters: Dict[str, Any]) -> Dict[str, Any]:
"""
Process enhanced metadata filters and convert them to vector store compatible format.
Args:
metadata_filters: Enhanced metadata filters with operators
Returns:
Dict of processed filters compatible with vector store
"""
processed_filters = {}
def process_condition(key: str, condition: Any) -> Dict[str, Any]:
if not isinstance(condition, dict):
# Simple equality: {"key": "value"}
if condition == "*":
# Wildcard: match everything for this field (implementation depends on vector store)
return {key: "*"}
return {key: condition}
result = {}
for operator, value in condition.items():
# Map platform operators to universal format that can be translated by each vector store
operator_map = {
"eq": "eq", "ne": "ne", "gt": "gt", "gte": "gte",
"lt": "lt", "lte": "lte", "in": "in", "nin": "nin",
"contains": "contains", "icontains": "icontains"
}
if operator in operator_map:
result.setdefault(key, {})[operator_map[operator]] = value
else:
raise ValueError(f"Unsupported metadata filter operator: {operator}")
return result
def merge_filters(target: Dict[str, Any], source: Dict[str, Any]) -> None:
"""Merge source into target, deep-merging nested operator dicts for the same key."""
for key, value in source.items():
if key in target and isinstance(target[key], dict) and isinstance(value, dict):
target[key].update(value)
else:
target[key] = value
for key, value in metadata_filters.items():
if key == "AND":
# Logical AND: combine multiple conditions
if not isinstance(value, list):
raise ValueError("AND operator requires a list of conditions")
for condition in value:
for sub_key, sub_value in condition.items():
merge_filters(processed_filters, process_condition(sub_key, sub_value))
elif key == "OR":
# Logical OR: Pass through to vector store for implementation-specific handling
if not isinstance(value, list) or not value:
raise ValueError("OR operator requires a non-empty list of conditions")
# Store OR conditions in a way that vector stores can interpret
processed_filters["$or"] = []
for condition in value:
or_condition = {}
for sub_key, sub_value in condition.items():
merge_filters(or_condition, process_condition(sub_key, sub_value))
processed_filters["$or"].append(or_condition)
elif key == "NOT":
# Logical NOT: Pass through to vector store for implementation-specific handling
if not isinstance(value, list) or not value:
raise ValueError("NOT operator requires a non-empty list of conditions")
processed_filters["$not"] = []
for condition in value:
not_condition = {}
for sub_key, sub_value in condition.items():
merge_filters(not_condition, process_condition(sub_key, sub_value))
processed_filters["$not"].append(not_condition)
else:
merge_filters(processed_filters, process_condition(key, value))
return processed_filters
def _has_advanced_operators(self, filters: Dict[str, Any]) -> bool:
"""
Check if filters contain advanced operators that need special processing.
Args:
filters: Dictionary of filters to check
Returns:
bool: True if advanced operators are detected
"""
if not isinstance(filters, dict):
return False
for key, value in filters.items():
# Check for platform-style logical operators
if key in ["AND", "OR", "NOT"]:
return True
# Check for comparison operators (without $ prefix for universal compatibility)
if isinstance(value, dict):
for op in value.keys():
if op in ["eq", "ne", "gt", "gte", "lt", "lte", "in", "nin", "contains", "icontains"]:
return True
# Check for wildcard values
if value == "*":
return True
return False
def _search_vector_store(self, query, filters, limit, threshold=0.1):
    """
    Run the search pipeline for a query and return formatted memories.

    The semantic search results form the candidate set. Keyword-search scores
    (normalized via ``normalize_bm25``) and entity boosts are passed alongside
    them to ``score_and_rank``, which applies `threshold` and `limit`.

    Args:
        query (str): Raw query text.
        filters (dict): Filters forwarded to the vector store searches.
        limit (int): Maximum number of results requested from ``score_and_rank``.
        threshold (float, optional): Minimum score; ``None`` is treated as 0.1.

    Returns:
        list[dict]: Dumped ``MemoryItem`` dicts (including "score"). Candidates
        whose payload has no "data" are dropped.
    """
    # Guard against None threshold (backward compat).
    if threshold is None:
        threshold = 0.1

    # Step 1: preprocess the query.
    lemmatized_query = lemmatize_for_bm25(query)
    entities_in_query = extract_entities(query)

    # Step 2: embed the query.
    query_vector = self.embedding_model.embed(query, "search")

    # Step 3: semantic search, over-fetching to give the scorer a larger pool.
    pool_size = max(limit * 4, 60)
    semantic_hits = self.vector_store.search(
        query=query, vectors=query_vector, top_k=pool_size, filters=filters
    )

    # Step 4: keyword search; a None result means there is nothing to score.
    keyword_hits = self.vector_store.keyword_search(
        query=lemmatized_query, top_k=pool_size, filters=filters
    )

    # Step 5: normalized BM25 score per memory id (hits may be objects or dicts).
    bm25_scores = {}
    if keyword_hits is not None:
        midpoint, steepness = get_bm25_params(query, lemmatized=lemmatized_query)
        for hit in keyword_hits:
            if hasattr(hit, 'id'):
                hit_id = str(hit.id)
            else:
                hit_id = str(hit.get('id', ''))
            if hasattr(hit, 'score'):
                raw_score = hit.score
            else:
                raw_score = hit.get('score', 0)
            if raw_score and raw_score > 0:
                bm25_scores[hit_id] = normalize_bm25(raw_score, midpoint, steepness)

    # Step 6: entity boosts, only when the query yielded entities.
    entity_boosts = (
        self._compute_entity_boosts(entities_in_query, filters) if entities_in_query else {}
    )

    # Step 7: candidate set built from the semantic hits only.
    candidates = [
        {
            "id": str(hit.id),
            "score": hit.score,
            "payload": hit.payload if hasattr(hit, 'payload') else {},
        }
        for hit in semantic_hits
    ]

    # Step 8: score and rank.
    ranked = score_and_rank(
        semantic_results=candidates,
        bm25_scores=bm25_scores,
        entity_boosts=entity_boosts,
        threshold=threshold,
        top_k=limit,
    )

    # Step 9: format results.
    promoted_fields = ("user_id", "agent_id", "run_id", "actor_id", "role")
    reserved_fields = {
        "data",
        "hash",
        "created_at",
        "updated_at",
        "id",
        "text_lemmatized",
        "attributed_to",
        *promoted_fields,
    }

    formatted = []
    for entry in ranked:
        payload = entry.get("payload") or {}
        if not payload.get("data"):
            # Skip candidates with no payload data.
            continue

        item = MemoryItem(
            id=entry["id"],
            memory=payload.get("data", ""),
            hash=payload.get("hash"),
            created_at=payload.get("created_at"),
            updated_at=payload.get("updated_at"),
            score=entry["score"],
        ).model_dump()

        for field in promoted_fields:
            if field in payload:
                item[field] = payload[field]

        extras = {name: value for name, value in payload.items() if name not in reserved_fields}
        if extras:
            if not item.get("metadata"):
                item["metadata"] = {}
            item["metadata"].update(extras)

        formatted.append(item)
    return formatted
def _compute_entity_boosts(self, query_entities, filters):
    """Compute per-memory entity boosts from entity store search.

    For each entity taken from the query:
    1. Embed the entity text
    2. Search the entity store and keep matches with similarity >= 0.5
    3. Boost every memory linked to a kept match

    Only the first 8 entries of `query_entities` are considered, and
    duplicates (case-insensitive, whitespace-trimmed) among them are dropped.
    Any exception raised while searching is logged and the boosts collected
    so far are returned.

    Returns:
        Dict mapping memory_id (str) -> highest boost computed for that memory.
    """
    unique_entities = []
    seen_keys = set()
    for kind, text in query_entities[:8]:
        normalized = text.strip().lower()
        if not normalized or normalized in seen_keys:
            continue
        seen_keys.add(normalized)
        unique_entities.append((kind, text))

    if not unique_entities:
        return {}

    # Entity lookups are scoped by the session identifiers only.
    scope = {
        name: value
        for name, value in filters.items()
        if name in ("user_id", "agent_id", "run_id") and value
    }

    boosts = {}
    try:
        for _, text in unique_entities:
            vector = self.embedding_model.embed(text, "search")
            hits = self.entity_store.search(
                query=text,
                vectors=vector,
                top_k=500,
                filters=scope,
            )
            for hit in hits:
                similarity = getattr(hit, 'score', 0.0)
                if similarity < 0.5:
                    continue

                hit_payload = getattr(hit, 'payload', {})
                linked = hit_payload.get("linked_memory_ids", [])
                if not isinstance(linked, list):
                    continue

                # Spread attenuation: the more memories an entity links to,
                # the smaller the boost each of them receives.
                link_count = max(len(linked), 1)
                spread_weight = 1.0 / (1.0 + 0.001 * ((link_count - 1) ** 2))
                boost = similarity * ENTITY_BOOST_WEIGHT * spread_weight

                for linked_id in linked:
                    if not linked_id:
                        continue
                    boost_key = str(linked_id)
                    boosts[boost_key] = max(boosts.get(boost_key, 0.0), boost)
    except Exception as e:
        logger.warning(f"Entity boost computation failed: {e}")
    return boosts
[docs]
def update(self, memory_id, data, metadata: Optional[Dict[str, Any]] = None):
    """
    Replace the content of an existing memory.

    The new text is embedded once here and handed to ``_update_memory``
    so it is not embedded a second time.

    Args:
        memory_id (str): ID of the memory to update.
        data (str): New content to update the memory with.
        metadata (dict, optional): Metadata to update with the memory. Defaults to None.

    Returns:
        dict: Success message indicating the memory was updated.

    Example:
        >>> m.update(memory_id="mem_123", data="Likes to play tennis on weekends")
        {'message': 'Memory updated successfully!'}
    """
    capture_event("durag.update", self, {"memory_id": memory_id, "sync_type": "sync"})

    fresh_vector = self.embedding_model.embed(data, "update")
    self._update_memory(memory_id, data, {data: fresh_vector}, metadata)
    return {"message": "Memory updated successfully!"}
[docs]
def delete(self, memory_id):
    """
    Delete a single memory by ID.

    Args:
        memory_id (str): ID of the memory to delete.

    Returns:
        dict: Success message indicating the memory was deleted.

    Raises:
        ValueError: If the vector store has no memory with this ID.
    """
    capture_event("durag.delete", self, {"memory_id": memory_id, "sync_type": "sync"})

    target = self.vector_store.get(vector_id=memory_id)
    if target is None:
        raise ValueError(f"Memory with id {memory_id} not found")

    # Pass the fetched record along so _delete_memory does not fetch it again.
    self._delete_memory(memory_id, target)
    return {"message": "Memory deleted successfully!"}
[docs]
def delete_all(self, user_id: Optional[str] = None, agent_id: Optional[str] = None, run_id: Optional[str] = None):
    """
    Delete every memory that matches the given session identifiers.

    Args:
        user_id (str, optional): ID of the user to delete memories for. Defaults to None.
        agent_id (str, optional): ID of the agent to delete memories for. Defaults to None.
        run_id (str, optional): ID of the run to delete memories for. Defaults to None.

    Returns:
        dict: Success message indicating the memories were deleted.

    Raises:
        ValueError: If none of user_id, agent_id, run_id is provided.
    """
    filters: Dict[str, Any] = {}
    if user_id:
        filters["user_id"] = user_id
    if agent_id:
        filters["agent_id"] = agent_id
    if run_id:
        filters["run_id"] = run_id

    if not filters:
        raise ValueError(
            "At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method."
        )

    keys, encoded_ids = process_telemetry_filters(filters)
    capture_event("durag.delete_all", self, {"keys": keys, "encoded_ids": encoded_ids, "sync_type": "sync"})

    # Vector stores return either a flat sequence of records or a sequence
    # whose first element is the record container. Blindly taking `[0]`
    # raised IndexError on an empty flat result and iterated a single record
    # on a non-empty one, so unwrap the same way _get_all_from_vector_store does.
    listed = self.vector_store.list(filters=filters)
    if isinstance(listed, (list, tuple)) and listed and isinstance(listed[0], (list, tuple)):
        memories = listed[0]
    else:
        memories = listed or []

    # Delete the matching memories one by one (history and entity links are
    # handled by _delete_memory); the collection itself is not reset here.
    for memory in memories:
        self._delete_memory(memory.id)

    logger.info(f"Deleted {len(memories)} memories")
    return {"message": "Memories deleted successfully!"}
[docs]
def history(self, memory_id):
    """
    Return the recorded change history of a memory.

    Args:
        memory_id (str): ID of the memory to get history for.

    Returns:
        list: List of changes for the memory, as returned by the history database.
    """
    capture_event("durag.history", self, {"memory_id": memory_id, "sync_type": "sync"})
    changes = self.db.get_history(memory_id)
    return changes
def _create_memory(self, data, existing_embeddings, metadata=None):
    """
    Insert a new memory into the vector store and record an "ADD" history entry.

    Args:
        data (str): Memory text.
        existing_embeddings (dict): Map of text -> embedding; reused when it
            already contains `data`, otherwise the text is embedded here.
        metadata (dict, optional): Extra payload fields. It is deep-copied, so
            the caller's dict is not modified.

    Returns:
        str: The newly generated memory ID (UUID4 string).
    """
    logger.debug(f"Creating memory with {data=}")

    vector = (
        existing_embeddings[data]
        if data in existing_embeddings
        else self.embedding_model.embed(data, memory_action="add")
    )

    new_id = str(uuid.uuid4())

    payload = {} if metadata is None else deepcopy(metadata)
    payload["data"] = data
    payload["hash"] = hashlib.md5(data.encode()).hexdigest()
    # Keep a caller-supplied creation time; otherwise stamp the current UTC time.
    payload.setdefault("created_at", datetime.now(timezone.utc).isoformat())
    payload["updated_at"] = payload["created_at"]
    payload["text_lemmatized"] = lemmatize_for_bm25(data)

    self.vector_store.insert(
        vectors=[vector],
        ids=[new_id],
        payloads=[payload],
    )
    self.db.add_history(
        new_id,
        None,
        data,
        "ADD",
        created_at=payload.get("created_at"),
        updated_at=payload.get("updated_at"),
        actor_id=payload.get("actor_id"),
        role=payload.get("role"),
    )
    return new_id
def _create_procedural_memory(self, messages, metadata=None, prompt=None):
"""
Create a procedural memory
Args:
messages (list): List of messages to create a procedural memory from.
metadata (dict): Metadata to create a procedural memory from.
prompt (str, optional): Prompt to use for the procedural memory creation. Defaults to None.
"""
logger.info("Creating procedural memory")
parsed_messages = [
{"role": "system", "content": prompt or PROCEDURAL_MEMORY_SYSTEM_PROMPT},
*messages,
{
"role": "user",
"content": "Create procedural memory of the above conversation.",
},
]
try:
procedural_memory = self.llm.generate_response(messages=parsed_messages)
procedural_memory = remove_code_blocks(procedural_memory)
except Exception as e:
logger.error(f"Error generating procedural memory summary: {e}")
raise
if metadata is None:
raise ValueError("Metadata cannot be done for procedural memory.")
metadata = {**metadata, "memory_type": MemoryType.PROCEDURAL.value}
embeddings = self.embedding_model.embed(procedural_memory, memory_action="add")
memory_id = self._create_memory(procedural_memory, {procedural_memory: embeddings}, metadata=metadata)
capture_event("durag._create_procedural_memory", self, {"memory_id": memory_id, "sync_type": "sync"})
result = {"results": [{"id": memory_id, "memory": procedural_memory, "event": "ADD"}]}
return result
def _update_memory(self, memory_id, data, existing_embeddings, metadata=None):
    """
    Overwrite a stored memory, record an "UPDATE" history entry and relink entities.

    Args:
        memory_id (str): ID of the memory to update.
        data (str): New memory text.
        existing_embeddings (dict): Map of text -> embedding; reused when it
            already contains `data`, otherwise the text is embedded here.
        metadata (dict, optional): New payload fields. It is deep-copied, so
            the caller's dict is not modified.

    Returns:
        str: The ID of the updated memory.

    Raises:
        ValueError: If the memory cannot be fetched or does not exist.
    """
    logger.info(f"Updating memory with {data=}")

    try:
        existing_memory = self.vector_store.get(vector_id=memory_id)
    except Exception:
        logger.error(f"Error getting memory with ID {memory_id} during update.")
        raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'")

    if existing_memory is None:
        raise ValueError(f"Memory with id {memory_id} not found. Please provide a valid 'memory_id'")

    old_payload = existing_memory.payload
    prev_value = old_payload.get("data")

    new_metadata = {} if metadata is None else deepcopy(metadata)
    new_metadata["data"] = data
    new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
    new_metadata["text_lemmatized"] = lemmatize_for_bm25(data)
    # The creation time always comes from the stored record.
    new_metadata["created_at"] = old_payload.get("created_at")
    new_metadata["updated_at"] = datetime.now(timezone.utc).isoformat()

    # Session identifiers are carried over only when the new metadata omits them.
    for session_key in ("user_id", "agent_id", "run_id"):
        if session_key not in new_metadata and session_key in old_payload:
            new_metadata[session_key] = old_payload[session_key]
    # The stored actor_id, when present, replaces any value in the new metadata.
    if "actor_id" in old_payload:
        new_metadata["actor_id"] = old_payload["actor_id"]
    if "role" not in new_metadata and "role" in old_payload:
        new_metadata["role"] = old_payload["role"]

    embeddings = (
        existing_embeddings[data]
        if data in existing_embeddings
        else self.embedding_model.embed(data, "update")
    )

    self.vector_store.update(
        vector_id=memory_id,
        vector=embeddings,
        payload=new_metadata,
    )
    logger.info(f"Updating memory with ID {memory_id=} with {data=}")

    self.db.add_history(
        memory_id,
        prev_value,
        data,
        "UPDATE",
        created_at=new_metadata["created_at"],
        updated_at=new_metadata["updated_at"],
        actor_id=new_metadata.get("actor_id"),
        role=new_metadata.get("role"),
    )

    # Entity-store cleanup: strip this memory's id from old-text entities,
    # then re-extract entities from the new text and link them back.
    session_filters = {
        session_key: new_metadata[session_key]
        for session_key in ("user_id", "agent_id", "run_id")
        if new_metadata.get(session_key)
    }
    self._remove_memory_from_entity_store(memory_id, session_filters)
    self._link_entities_for_memory(memory_id, data, session_filters)
    return memory_id
def _delete_memory(self, memory_id, existing_memory=None):
    """
    Delete a memory, record a "DELETE" history entry and unlink its entities.

    Args:
        memory_id (str): ID of the memory to delete.
        existing_memory (optional): The already-fetched record for `memory_id`;
            when omitted it is fetched from the vector store.

    Returns:
        str: The ID of the deleted memory.

    Raises:
        ValueError: If no record is supplied and the vector store has none
            for `memory_id`.
    """
    logger.info(f"Deleting memory with {memory_id=}")

    if existing_memory is None:
        existing_memory = self.vector_store.get(vector_id=memory_id)
        if existing_memory is None:
            raise ValueError(f"Memory with id {memory_id} not found. Please provide a valid 'memory_id'")

    # Normalize a missing payload once, before any field is read from it.
    # The guard used to be computed after `.payload.get(...)` had already
    # been called, so a None payload crashed before the guard took effect.
    payload = existing_memory.payload or {}

    prev_value = payload.get("data", "")
    created_at = _normalize_iso_timestamp_to_utc(payload.get("created_at"))
    updated_at = datetime.now(timezone.utc).isoformat()
    session_filters = {k: payload[k] for k in ("user_id", "agent_id", "run_id") if payload.get(k)}

    self.vector_store.delete(vector_id=memory_id)
    self.db.add_history(
        memory_id,
        prev_value,
        None,
        "DELETE",
        created_at=created_at,
        updated_at=updated_at,
        actor_id=payload.get("actor_id"),
        role=payload.get("role"),
        is_deleted=1,
    )

    # Entity-store cleanup: strip this memory's id from any entity records
    # that linked to it.
    self._remove_memory_from_entity_store(memory_id, session_filters)
    return memory_id
[docs]
def reset(self):
    """
    Wipe all stored memories and recreate the backing stores.

    Steps:
        1. Drop the history table, close the SQLite connection and open a
           fresh ``SQLiteManager``.
        2. Reset the vector store through the factory when the store exposes
           ``reset``; otherwise delete its collection and create a new store.
        3. Reset and forget the entity store if one was initialized.
    """
    logger.warning("Resetting all memories")

    db_connection = getattr(self.db, "connection", None)
    if db_connection:
        self.db.connection.execute("DROP TABLE IF EXISTS history")
        self.db.connection.close()
    self.db = SQLiteManager(self.config.history_db_path)

    if not hasattr(self.vector_store, "reset"):
        logger.warning("Vector store does not support reset. Skipping.")
        # Fallback: drop the collection and build a fresh store from config.
        self.vector_store.delete_col()
        self.vector_store = VectorStoreFactory.create(
            self.config.vector_store.provider, self.config.vector_store.config
        )
    else:
        self.vector_store = VectorStoreFactory.reset(self.vector_store)

    # Reset entity store if initialized; a failure is logged, not raised.
    if self._entity_store is not None:
        try:
            self._entity_store.reset()
        except Exception as e:
            logger.warning(f"Failed to reset entity store: {e}")
        self._entity_store = None

    capture_event("durag.reset", self, {"sync_type": "sync"})
[docs]
def close(self):
    """Release resources held by this Memory instance (Qdrant locks, SQLite, etc.)."""
    # Closed in order: main vector store (releases the Qdrant file lock in
    # local mode), telemetry vector store (separate Qdrant path), SQLite.
    # An attribute that is missing or already None is skipped.
    for attr_name in ("vector_store", "_telemetry_vector_store", "db"):
        resource = getattr(self, attr_name, None)
        if resource is not None:
            resource.close()
            setattr(self, attr_name, None)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def __del__(self):
self.close()
[docs]
def chat(self, query):
    """Placeholder for a chat interface; always raises ``NotImplementedError``."""
    raise NotImplementedError("Chat function not implemented yet.")
class AsyncMemory(MemoryBase):
def __init__(self, config: Optional[MemoryConfig] = None):
    """
    Build the embedder, vector store, LLM, history database and optional
    reranker described by `config`.

    Args:
        config (MemoryConfig, optional): Memory configuration. When omitted, a
            fresh default ``MemoryConfig()`` is created for this instance.
    """
    # A `MemoryConfig()` default in the signature is evaluated once at
    # definition time and then shared by every instance created without an
    # explicit config; build it per call instead.
    if config is None:
        config = MemoryConfig()

    self.config = config

    self.embedding_model = EmbedderFactory.create(
        self.config.embedder.provider,
        self.config.embedder.config,
        self.config.vector_store.config,
    )
    self.vector_store = VectorStoreFactory.create(
        self.config.vector_store.provider, self.config.vector_store.config
    )
    self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
    self.db = SQLiteManager(self.config.history_db_path)
    self.collection_name = self.config.vector_store.config.collection_name
    self.api_version = self.config.version
    self.custom_instructions = self.config.custom_instructions
    # Created lazily by the `entity_store` property.
    self._entity_store = None

    # Initialize reranker if configured
    self.reranker = None
    if config.reranker:
        self.reranker = RerankerFactory.create(
            config.reranker.provider,
            config.reranker.config,
        )

    if DURAG_TELEMETRY:
        # Telemetry uses its own collection, configured on a copy so the
        # main vector store config is left untouched.
        telemetry_config = _safe_deepcopy_config(self.config.vector_store.config)
        telemetry_config.collection_name = "duragmigrations"
        if self.config.vector_store.provider in ["faiss", "qdrant"]:
            provider_path = f"migrations_{self.config.vector_store.provider}"
            telemetry_config.path = os.path.join(durag_dir, provider_path)
            os.makedirs(telemetry_config.path, exist_ok=True)
        self._telemetry_vector_store = VectorStoreFactory.create(self.config.vector_store.provider, telemetry_config)

    capture_event("durag.init", self, {"sync_type": "async"})
@property
def entity_store(self):
    """Entity vector store, created on first access and cached afterwards.

    The store is built from a copy of the main vector store config with the
    collection name switched to ``"<collection_name>_entities"``.
    """
    if self._entity_store is not None:
        return self._entity_store

    def _assign(config_obj, field, value):
        # Configs may be attribute-style objects or plain dicts; anything
        # that is neither is left unchanged.
        if hasattr(config_obj, field):
            setattr(config_obj, field, value)
        elif isinstance(config_obj, dict):
            config_obj[field] = value

    entity_config = _safe_deepcopy_config(self.config.vector_store.config)
    _assign(entity_config, 'collection_name', f"{self.collection_name}_entities")

    # For Qdrant, share the existing client to avoid RocksDB lock contention
    # when using embedded mode (path=...). QdrantConfig.client takes precedence
    # over host/port/path.
    if self.config.vector_store.provider == "qdrant" and hasattr(self.vector_store, "client"):
        _assign(entity_config, "client", self.vector_store.client)

    self._entity_store = VectorStoreFactory.create(
        self.config.vector_store.provider, entity_config
    )
    return self._entity_store
async def _upsert_entity_async(self, entity_text, entity_type, memory_id, filters):
    """Async variant of `_upsert_entity` — per-entity search-then-update-or-insert.

    The entity text is embedded and the closest stored entity is looked up.
    A match scoring >= 0.95 gets `memory_id` appended to its
    "linked_memory_ids"; otherwise a new entity record is inserted. Blocking
    calls run via ``asyncio.to_thread``. Failures are logged as warnings and
    never raised.
    """
    try:
        entity_vector = await asyncio.to_thread(self.embedding_model.embed, entity_text, "add")
        scope = {
            name: value
            for name, value in filters.items()
            if name in ("user_id", "agent_id", "run_id") and value
        }

        existing = await asyncio.to_thread(
            self.entity_store.search,
            query=entity_text,
            vectors=entity_vector,
            top_k=1,
            filters=scope,
        )

        already_known = bool(existing) and existing[0].score >= 0.95
        if not already_known:
            # No close-enough entity: insert a new record linked to this memory.
            new_entity_id = str(uuid.uuid4())
            new_entity_payload = {
                "data": entity_text,
                "entity_type": entity_type,
                "linked_memory_ids": [memory_id],
                **scope,
            }
            await asyncio.to_thread(
                self.entity_store.insert,
                vectors=[entity_vector],
                ids=[new_entity_id],
                payloads=[new_entity_payload],
            )
            return

        closest = existing[0]
        closest_payload = closest.payload or {}
        linked = closest_payload.get("linked_memory_ids", [])
        if memory_id not in linked:
            linked.append(memory_id)
            closest_payload["linked_memory_ids"] = linked
            await asyncio.to_thread(
                self.entity_store.update,
                vector_id=closest.id,
                vector=None,
                payload=closest_payload,
            )
    except Exception as e:
        logger.warning(f"Entity upsert failed for '{entity_text}' (async): {e}")
async def _remove_memory_from_entity_store(self, memory_id, filters):
    """Async variant of `Memory._remove_memory_from_entity_store`.

    Lists entity records in the session scope and removes `memory_id` from
    each record's "linked_memory_ids". A record left with no links is
    deleted; otherwise it is re-embedded and updated with the remaining
    links. Nothing happens when the entity store was never initialized.
    Failures are logged and never raised.
    """
    if self._entity_store is None:
        return

    scope = {
        name: value
        for name, value in filters.items()
        if name in ("user_id", "agent_id", "run_id") and value
    }

    try:
        listing = await asyncio.to_thread(self.entity_store.list, filters=scope, top_k=10000)
        # Unwrap a nested listing: a non-empty sequence whose first element is a list.
        if isinstance(listing, (list, tuple)) and listing and isinstance(listing[0], list):
            entities = listing[0]
        else:
            entities = listing

        for entity in entities or []:
            try:
                entity_payload = getattr(entity, "payload", None) or {}
                linked = entity_payload.get("linked_memory_ids", [])
                if not isinstance(linked, list) or memory_id not in linked:
                    continue

                survivors = [linked_id for linked_id in linked if linked_id != memory_id]

                if not survivors:
                    # This memory was the entity's only link: drop the entity.
                    try:
                        await asyncio.to_thread(self.entity_store.delete, vector_id=entity.id)
                    except Exception as e:
                        logger.debug(f"Entity delete failed for id={entity.id} (async): {e}")
                    continue

                entity_text = entity_payload.get("data")
                if not isinstance(entity_text, str) or not entity_text:
                    logger.debug(f"Entity id={entity.id} missing 'data'; skipping update during cleanup (async)")
                    continue

                # The update call is given a vector, so re-embed the entity text.
                try:
                    vector = await asyncio.to_thread(self.embedding_model.embed, entity_text, "update")
                except Exception as e:
                    logger.debug(f"Entity re-embed failed for '{entity_text}' (async): {e}")
                    continue

                updated_payload = {**entity_payload, "linked_memory_ids": survivors}
                try:
                    await asyncio.to_thread(
                        self.entity_store.update,
                        vector_id=entity.id,
                        vector=vector,
                        payload=updated_payload,
                    )
                except Exception as e:
                    logger.debug(f"Entity update failed for id={entity.id} (async): {e}")
            except Exception as e:
                logger.debug(f"Entity cleanup error (async): {e}")
    except Exception as e:
        logger.warning(f"Entity store cleanup failed for memory_id={memory_id} (async): {e}")
async def _link_entities_for_memory(self, memory_id, text, filters):
    """Async variant of `Memory._link_entities_for_memory`.

    Extracts entities from `text` and upserts each distinct one
    (case-insensitive, whitespace-trimmed) so it links to `memory_id`.
    Failures are logged and never raised.
    """
    try:
        extracted = await asyncio.to_thread(extract_entities, text)
        if not extracted:
            return

        handled_keys = set()
        for entity_type, entity_text in extracted:
            normalized = entity_text.strip().lower()
            if normalized and normalized not in handled_keys:
                handled_keys.add(normalized)
                try:
                    await self._upsert_entity_async(entity_text, entity_type, memory_id, filters)
                except Exception as e:
                    logger.debug(f"Entity link failed for '{entity_text}' (async): {e}")
    except Exception as e:
        logger.warning(f"Entity linking failed for memory_id={memory_id} (async): {e}")
@classmethod
def from_config(cls, config_dict: Dict[str, Any]):
    """
    Build an instance from a plain configuration dictionary.

    Args:
        config_dict (dict): Keyword arguments for ``MemoryConfig``.

    Returns:
        An instance of this class constructed with the validated config.

    Raises:
        ValidationError: If the configuration fails validation; the error is
            logged and re-raised.
    """
    try:
        # Validate the dict returned by `_process_config`; its result used to
        # be discarded and the raw `config_dict` validated instead.
        processed_config = cls._process_config(config_dict)
        config = MemoryConfig(**processed_config)
    except ValidationError as e:
        logger.error(f"Configuration validation error: {e}")
        raise
    return cls(config)
@staticmethod
def _process_config(config_dict: Dict[str, Any]) -> Dict[str, Any]:
try:
return config_dict
except ValidationError as e:
logger.error(f"Configuration validation error: {e}")
raise
def _should_use_agent_memory_extraction(self, messages, metadata):
"""Determine whether to use agent memory extraction based on the logic:
- If agent_id is present and messages contain assistant role -> True
- Otherwise -> False
Args:
messages: List of message dictionaries
metadata: Metadata containing user_id, agent_id, etc.
Returns:
bool: True if should use agent memory extraction, False for user memory extraction
"""
# Check if agent_id is present in metadata
has_agent_id = metadata.get("agent_id") is not None
# Check if there are assistant role messages
has_assistant_messages = any(msg.get("role") == "assistant" for msg in messages)
# Use agent memory extraction if agent_id is present and there are assistant messages
return has_agent_id and has_assistant_messages
async def add(
    self,
    messages,
    *,
    user_id: Optional[str] = None,
    agent_id: Optional[str] = None,
    run_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    infer: bool = True,
    memory_type: Optional[str] = None,
    prompt: Optional[str] = None,
    llm=None,
):
    """
    Asynchronously store new memories derived from ``messages``.

    Args:
        messages (str or dict or List[Dict[str, str]]): Content to remember. A string is
            wrapped as a single user message; a dict is wrapped in a list.
        user_id (str, optional): ID of the user creating the memory.
        agent_id (str, optional): ID of the agent creating the memory. Defaults to None.
        run_id (str, optional): ID of the run creating the memory. Defaults to None.
        metadata (dict, optional): Metadata to store with the memory. Defaults to None.
        infer (bool, optional): Whether memories are inferred via the LLM. Defaults to True.
        memory_type (str, optional): Only "procedural_memory" is accepted. Defaults to None.
        prompt (str, optional): Prompt override for memory creation. Defaults to None.
        llm (BaseChatModel, optional): LLM used for procedural memories, e.g. a
            LangChain ChatModel. Defaults to None.

    Returns:
        dict: The procedural-memory result when ``agent_id`` is set and
        ``memory_type`` is procedural; otherwise ``{"results": [...]}``.

    Raises:
        ValueError: If ``memory_type`` is given but is not the procedural type.
        DuragValidationError: If ``messages`` is not a str, dict, or list.
    """
    processed_metadata, effective_filters = _build_filters_and_metadata(
        user_id=user_id, agent_id=agent_id, run_id=run_id, input_metadata=metadata
    )

    if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value:
        raise ValueError(
            f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories."
        )

    # Reject unsupported input types up front, then normalise to a list.
    if not isinstance(messages, (str, dict, list)):
        raise DuragValidationError(
            message="messages must be str, dict, or list[dict]",
            error_code="VALIDATION_003",
            details={"provided_type": type(messages).__name__, "valid_types": ["str", "dict", "list[dict]"]},
            suggestion="Convert your input to a string, dictionary, or list of dictionaries."
        )
    if isinstance(messages, str):
        messages = [{"role": "user", "content": messages}]
    elif isinstance(messages, dict):
        messages = [messages]

    # Procedural memories take a dedicated path and bypass the vector pipeline.
    if agent_id is not None and memory_type == MemoryType.PROCEDURAL.value:
        return await self._create_procedural_memory(
            messages, metadata=processed_metadata, prompt=prompt, llm=llm
        )

    llm_settings = self.config.llm.config
    if llm_settings.get("enable_vision"):
        messages = parse_vision_messages(messages, self.llm, llm_settings.get("vision_details"))
    else:
        messages = parse_vision_messages(messages)

    stored = await self._add_to_vector_store(
        messages, processed_metadata, effective_filters, infer, prompt=prompt
    )
    return {"results": stored}
async def _add_to_vector_store(
    self,
    messages: list,
    metadata: dict,
    effective_filters: dict,
    infer: bool,
    prompt: Optional[str] = None,
):
    """Persist memories for ``messages`` and return the list of ADD events.

    With ``infer=False`` every valid non-system message is embedded and stored
    verbatim via ``_create_memory``. Otherwise the phased batch pipeline runs:
    context gathering, retrieval of existing memories, one LLM extraction
    call, batch embedding, hash de-duplication, batch persistence, history
    recording, entity linking and finally saving the messages.

    Fixes relative to the previous version:
    - Memories whose fallback (per-item) insert failed are no longer written
      to history, entity-linked, or reported as added.
    - Non-dict items in the LLM's "memory" list are ignored instead of
      raising an uncaught ``AttributeError``.
    - Entities whose normalised text is empty are skipped, matching
      ``_link_entities_for_memory``.

    Args:
        messages: Normalised list of message dicts.
        metadata: Base metadata copied into every stored payload.
        effective_filters: Scope filters (user_id / agent_id / run_id, ...).
        infer: Whether to extract memories with the LLM.
        prompt: Optional custom instructions; falls back to
            ``self.custom_instructions``.

    Returns:
        list: One dict per stored memory (``id``, ``memory``, ``event`` and,
        on the non-infer path, ``actor_id`` and ``role``). Empty when nothing
        was stored.
    """
    if not infer:
        returned_memories = []
        for message_dict in messages:
            if (
                not isinstance(message_dict, dict)
                or message_dict.get("role") is None
                or message_dict.get("content") is None
            ):
                logger.warning(f"Skipping invalid message format (async): {message_dict}")
                continue

            # System messages are never stored as memories.
            if message_dict["role"] == "system":
                continue

            per_msg_meta = deepcopy(metadata)
            per_msg_meta["role"] = message_dict["role"]

            actor_name = message_dict.get("name")
            if actor_name:
                per_msg_meta["actor_id"] = actor_name

            msg_content = message_dict["content"]
            msg_embeddings = await asyncio.to_thread(self.embedding_model.embed, msg_content, "add")
            mem_id = await self._create_memory(msg_content, {msg_content: msg_embeddings}, per_msg_meta)

            returned_memories.append(
                {
                    "id": mem_id,
                    "memory": msg_content,
                    "event": "ADD",
                    "actor_id": actor_name if actor_name else None,
                    "role": message_dict["role"],
                }
            )
        return returned_memories

    # === V3 PHASED BATCH PIPELINE (async) ===

    # Phase 0: Context gathering
    session_scope = _build_session_scope(effective_filters)
    last_messages = await asyncio.to_thread(self.db.get_last_messages, session_scope, 10)
    parsed_messages = parse_messages(messages)

    # Phase 1: Existing memory retrieval
    search_filters = {k: v for k, v in effective_filters.items() if k in ("user_id", "agent_id", "run_id") and v}
    query_embedding = await asyncio.to_thread(self.embedding_model.embed, parsed_messages, "search")
    existing_results = await asyncio.to_thread(
        self.vector_store.search,
        query=parsed_messages,
        vectors=query_embedding,
        top_k=10,
        filters=search_filters,
    )

    # Map UUIDs to integers (anti-hallucination)
    existing_memories = []
    uuid_mapping = {}
    for idx, mem in enumerate(existing_results):
        uuid_mapping[str(idx)] = mem.id
        existing_memories.append({"id": str(idx), "text": mem.payload.get("data", "")})

    # Phase 2: LLM extraction (single call)
    is_agent_scoped = bool(effective_filters.get("agent_id")) and not effective_filters.get("user_id")
    system_prompt = ADDITIVE_EXTRACTION_PROMPT
    if is_agent_scoped:
        system_prompt += AGENT_CONTEXT_SUFFIX

    custom_instr = prompt or self.custom_instructions
    user_prompt = generate_additive_extraction_prompt(
        existing_memories=existing_memories,
        new_messages=parsed_messages,
        last_k_messages=last_messages,
        custom_instructions=custom_instr,
    )

    try:
        response = await asyncio.to_thread(
            self.llm.generate_response,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            response_format={"type": "json_object"},
        )
    except Exception as e:
        logger.error(f"LLM extraction failed (async): {e}")
        return []

    # Parse response
    try:
        response = remove_code_blocks(response)
        if not response or not response.strip():
            extracted_memories = []
        else:
            try:
                extracted_memories = json.loads(response, strict=False).get("memory", [])
            except json.JSONDecodeError:
                extracted_json = extract_json(response)
                extracted_memories = json.loads(extracted_json, strict=False).get("memory", [])
    except Exception as e:
        logger.error(f"Error parsing extraction response (async): {e}")
        extracted_memories = []

    # The LLM output is untrusted: keep only well-formed (dict) memory entries.
    if not isinstance(extracted_memories, list):
        extracted_memories = []
    extracted_memories = [m for m in extracted_memories if isinstance(m, dict)]

    if not extracted_memories:
        await asyncio.to_thread(self.db.save_messages, messages, session_scope)
        return []

    # Phase 3: Batch embed all extracted memory texts
    mem_texts = [m.get("text", "") for m in extracted_memories if m.get("text")]
    try:
        mem_embeddings_list = await asyncio.to_thread(self.embedding_model.embed_batch, mem_texts, "add")
        embed_map = dict(zip(mem_texts, mem_embeddings_list))
    except Exception:
        # Batch embedding failed: fall back to one request per text.
        embed_map = {}
        for text in mem_texts:
            try:
                embed_map[text] = await asyncio.to_thread(self.embedding_model.embed, text, "add")
            except Exception as e:
                logger.warning(f"Failed to embed memory text (async): {e}")

    # Phase 4: Per-memory CPU processing + Phase 5: Hash dedup
    existing_hashes = set()
    for mem in existing_results:
        h = mem.payload.get("hash") if hasattr(mem, "payload") and mem.payload else None
        if h:
            existing_hashes.add(h)

    records = []
    seen_hashes = set()
    for mem in extracted_memories:
        text = mem.get("text")
        if not text or text not in embed_map:
            continue

        mem_hash = hashlib.md5(text.encode()).hexdigest()
        if mem_hash in existing_hashes or mem_hash in seen_hashes:
            logger.debug(f"Skipping duplicate memory (hash match, async): {text[:50]}")
            continue
        seen_hashes.add(mem_hash)

        text_lemmatized = lemmatize_for_bm25(text)
        memory_id = str(uuid.uuid4())

        mem_metadata = deepcopy(metadata)
        mem_metadata["data"] = text
        mem_metadata["text_lemmatized"] = text_lemmatized
        mem_metadata["hash"] = mem_hash
        if "created_at" not in mem_metadata:
            mem_metadata["created_at"] = datetime.now(timezone.utc).isoformat()
        mem_metadata["updated_at"] = mem_metadata["created_at"]
        if mem.get("attributed_to"):
            mem_metadata["attributed_to"] = mem["attributed_to"]

        # Record layout: (memory_id, text, embedding, payload)
        records.append((memory_id, text, embed_map[text], mem_metadata))

    if not records:
        await asyncio.to_thread(self.db.save_messages, messages, session_scope)
        return []

    # Phase 6: Batch persist
    all_vectors = [r[2] for r in records]
    all_ids = [r[0] for r in records]
    all_payloads = [r[3] for r in records]
    try:
        await asyncio.to_thread(
            self.vector_store.insert,
            vectors=all_vectors,
            ids=all_ids,
            payloads=all_payloads,
        )
    except Exception:
        # Batch insert failed: retry one by one and remember which ones failed
        # so they are not reported (or recorded in history) as added.
        failed_ids = set()
        for mid, vec, pay in zip(all_ids, all_vectors, all_payloads):
            try:
                await asyncio.to_thread(self.vector_store.insert, vectors=[vec], ids=[mid], payloads=[pay])
            except Exception as e:
                logger.error(f"Failed to insert memory {mid} (async): {e}")
                failed_ids.add(mid)
        if failed_ids:
            records = [r for r in records if r[0] not in failed_ids]

    if not records:
        # Nothing could be persisted; keep the conversation context regardless.
        await asyncio.to_thread(self.db.save_messages, messages, session_scope)
        return []

    # Batch history
    history_records = [
        {
            "memory_id": r[0],
            "old_memory": None,
            "new_memory": r[1],
            "event": "ADD",
            "created_at": r[3].get("created_at"),
            "is_deleted": 0,
        }
        for r in records
    ]
    try:
        await asyncio.to_thread(self.db.batch_add_history, history_records)
    except Exception:
        for hr in history_records:
            try:
                await asyncio.to_thread(
                    self.db.add_history, hr["memory_id"], None, hr["new_memory"], "ADD",
                    created_at=hr.get("created_at")
                )
            except Exception as e:
                logger.error(f"Failed to add history for {hr['memory_id']} (async): {e}")

    # Phase 7: Batch entity linking (best-effort; never fails the add)
    try:
        all_texts = [r[1] for r in records]
        all_entities = await asyncio.to_thread(extract_entities_batch, all_texts)

        # 7a: Global dedup -> key: [entity_type, entity_text, {memory_ids}]
        global_entities = {}
        for idx, (memory_id, _, _, _) in enumerate(records):
            entities = all_entities[idx] if idx < len(all_entities) else []
            for entity_type, entity_text in entities:
                key = entity_text.strip().lower()
                if not key:
                    continue
                if key in global_entities:
                    global_entities[key][2].add(memory_id)
                else:
                    global_entities[key] = [entity_type, entity_text, {memory_id}]

        if global_entities:
            ordered_keys = list(global_entities.keys())
            entity_texts = [global_entities[k][1] for k in ordered_keys]

            # 7b: Batch embed entities
            try:
                entity_embeddings = await asyncio.to_thread(self.embedding_model.embed_batch, entity_texts, "add")
            except Exception:
                entity_embeddings = []
                for t in entity_texts:
                    try:
                        entity_embeddings.append(await asyncio.to_thread(self.embedding_model.embed, t, "add"))
                    except Exception:
                        # None marks an entity that could not be embedded.
                        entity_embeddings.append(None)

            valid = [(i, k) for i, k in enumerate(ordered_keys) if entity_embeddings[i] is not None]
            if valid:
                valid_indices, valid_keys = zip(*valid)
                valid_vectors = [entity_embeddings[i] for i in valid_indices]

                # 7c: Batch search for existing entities
                valid_texts = [global_entities[k][1] for k in valid_keys]
                existing_matches = await asyncio.to_thread(
                    self.entity_store.search_batch,
                    queries=valid_texts,
                    vectors_list=valid_vectors,
                    top_k=1,
                    filters=search_filters,
                )

                # 7d: Separate into inserts vs updates
                to_insert_vectors, to_insert_ids, to_insert_payloads = [], [], []
                for j, key in enumerate(valid_keys):
                    entity_type, entity_text, memory_ids = global_entities[key]
                    matches = existing_matches[j] if j < len(existing_matches) else []
                    if matches and matches[0].score >= 0.95:
                        # Near-identical entity already stored: extend its links.
                        match = matches[0]
                        payload = match.payload or {}
                        linked = set(payload.get("linked_memory_ids", []))
                        linked |= memory_ids
                        payload["linked_memory_ids"] = sorted(linked)
                        try:
                            await asyncio.to_thread(
                                self.entity_store.update,
                                vector_id=match.id,
                                vector=None,
                                payload=payload,
                            )
                        except Exception as e:
                            logger.debug(f"Entity update failed for '{entity_text}' (async): {e}")
                    else:
                        to_insert_vectors.append(valid_vectors[j])
                        to_insert_ids.append(str(uuid.uuid4()))
                        to_insert_payloads.append({
                            "data": entity_text,
                            "entity_type": entity_type,
                            "linked_memory_ids": sorted(memory_ids),
                            **search_filters,
                        })

                # 7e: Batch insert new entities
                if to_insert_vectors:
                    try:
                        await asyncio.to_thread(
                            self.entity_store.insert,
                            vectors=to_insert_vectors,
                            ids=to_insert_ids,
                            payloads=to_insert_payloads,
                        )
                    except Exception as e:
                        logger.warning(f"Batch entity insert failed (async): {e}")
    except Exception as e:
        logger.warning(f"Batch entity linking failed (async): {e}")

    # Phase 8: Save messages + return
    await asyncio.to_thread(self.db.save_messages, messages, session_scope)

    returned_memories = [
        {"id": r[0], "memory": r[1], "event": "ADD"}
        for r in records
    ]

    keys, encoded_ids = process_telemetry_filters(effective_filters)
    capture_event(
        "durag.add",
        self,
        {"version": self.api_version, "keys": keys, "encoded_ids": encoded_ids, "sync_type": "async"},
    )
    return returned_memories
async def get(self, memory_id):
    """
    Fetch a single memory by its ID asynchronously.

    Args:
        memory_id (str): ID of the memory to retrieve.

    Returns:
        dict: The formatted memory, or None when the vector store returns
        nothing for this ID. Scope keys (user_id, agent_id, run_id, actor_id,
        role) are promoted to the top level; any other non-core payload keys
        are grouped under "metadata".
    """
    capture_event("durag.get", self, {"memory_id": memory_id, "sync_type": "async"})

    record = await asyncio.to_thread(self.vector_store.get, vector_id=memory_id)
    if not record:
        return None

    promoted = ("user_id", "agent_id", "run_id", "actor_id", "role")
    # Payload keys that never appear under "metadata".
    reserved = {"data", "hash", "created_at", "updated_at", "id", "text_lemmatized", "attributed_to", *promoted}

    item = MemoryItem(
        id=record.id,
        memory=record.payload.get("data", ""),
        hash=record.payload.get("hash"),
        created_at=record.payload.get("created_at"),
        updated_at=record.payload.get("updated_at"),
    ).model_dump()

    item.update({name: record.payload[name] for name in promoted if name in record.payload})

    extras = {name: value for name, value in record.payload.items() if name not in reserved}
    if extras:
        item["metadata"] = extras

    return item
async def get_all(
    self,
    *,
    filters: Optional[Dict[str, Any]] = None,
    top_k: int = 20,
    **kwargs,
):
    """
    List memories matching the given scope filters.

    Args:
        filters (dict): Filter dict containing entity IDs and optional metadata filters.
            Must contain at least one of: user_id, agent_id, run_id.
            Example: filters={"user_id": "u1", "agent_id": "a1"}
        top_k (int, optional): The maximum number of memories to return. Defaults to 20.

    Returns:
        dict: ``{"results": [...]}`` where each entry is a formatted memory,
        e.g. `{"results": [{"id": "...", "memory": "...", ...}]}`.

    Raises:
        ValueError: If filters doesn't contain at least one of user_id, agent_id, run_id,
            or if top_k is invalid.
    """
    # Entity IDs must arrive through `filters`, never as top-level kwargs.
    _reject_top_level_entity_params(kwargs, "get_all")
    _validate_search_params(top_k=top_k)

    entity_keys = ("user_id", "agent_id", "run_id")

    # Work on a copy so the caller's dict is never mutated.
    effective_filters = dict(filters) if filters else {}
    for entity_key in entity_keys:
        if entity_key in effective_filters:
            effective_filters[entity_key] = _validate_and_trim_entity_id(
                effective_filters[entity_key], entity_key
            )

    if all(entity_key not in effective_filters for entity_key in entity_keys):
        raise ValueError(
            "filters must contain at least one of: user_id, agent_id, run_id. "
            "Example: filters={'user_id': 'u1'}"
        )

    keys, encoded_ids = process_telemetry_filters(effective_filters)
    capture_event(
        "durag.get_all", self, {"limit": top_k, "keys": keys, "encoded_ids": encoded_ids, "sync_type": "async"}
    )

    listed = await self._get_all_from_vector_store(effective_filters, top_k)
    return {"results": listed}
async def _get_all_from_vector_store(self, filters, limit):
    """List memories from the vector store and format them for API output.

    Args:
        filters: Filters forwarded to ``vector_store.list``.
        limit: Maximum number of memories requested (passed as ``top_k``).

    Returns:
        list: One dict per memory, without a score. Scope keys are promoted
        to the top level and remaining non-core payload keys are grouped
        under "metadata".
    """
    listed = await asyncio.to_thread(self.vector_store.list, filters=filters, top_k=limit)

    # Some stores wrap the records in an outer container; unwrap one level
    # when the first element is itself a list/tuple.
    rows = listed
    if isinstance(listed, (tuple, list)) and len(listed) > 0 and isinstance(listed[0], (list, tuple)):
        rows = listed[0]

    promoted = ("user_id", "agent_id", "run_id", "actor_id", "role")
    reserved = {"data", "hash", "created_at", "updated_at", "id", "text_lemmatized", "attributed_to", *promoted}

    formatted = []
    for row in rows:
        entry = MemoryItem(
            id=row.id,
            memory=row.payload.get("data", ""),
            hash=row.payload.get("hash"),
            created_at=row.payload.get("created_at"),
            updated_at=row.payload.get("updated_at"),
        ).model_dump(exclude={"score"})

        entry.update({name: row.payload[name] for name in promoted if name in row.payload})

        extras = {name: value for name, value in row.payload.items() if name not in reserved}
        if extras:
            entry["metadata"] = extras

        formatted.append(entry)

    return formatted
async def search(
    self,
    query: str,
    *,
    top_k: int = 20,
    filters: Optional[Dict[str, Any]] = None,
    threshold: float = 0.1,
    rerank: bool = False,
    **kwargs,
):
    """
    Search stored memories for entries relevant to ``query``.

    Args:
        query (str): Query to search for.
        top_k (int, optional): Maximum number of results to return. Defaults to 20.
        filters (dict): Filter dict containing entity IDs and optional metadata filters.
            Must contain at least one of: user_id, agent_id, run_id.
            Example: filters={"user_id": "u1", "agent_id": "a1"}

            Supported metadata filter forms:
            - {"key": "value"} - exact match
            - {"key": {"eq": "value"}} - equals
            - {"key": {"ne": "value"}} - not equals
            - {"key": {"in": ["val1", "val2"]}} - in list
            - {"key": {"nin": ["val1", "val2"]}} - not in list
            - {"key": {"gt": 10}} - greater than
            - {"key": {"gte": 10}} - greater than or equal
            - {"key": {"lt": 10}} - less than
            - {"key": {"lte": 10}} - less than or equal
            - {"key": {"contains": "text"}} - contains text
            - {"key": {"icontains": "text"}} - case-insensitive contains
            - {"key": "*"} - wildcard match (any value)
            - {"AND": [filter1, filter2]} - logical AND
            - {"OR": [filter1, filter2]} - logical OR
            - {"NOT": [filter1]} - logical NOT
        threshold (float, optional): Minimum score for a memory to be included. Defaults to 0.1.
        rerank (bool, optional): Whether to rerank results. Defaults to False.

    Returns:
        dict: ``{"results": [...]}`` with one scored entry per memory,
        e.g. `{"results": [{"id": "...", "memory": "...", "score": 0.8, ...}]}`.

    Raises:
        ValueError: If filters doesn't contain at least one of user_id, agent_id, run_id,
            or if threshold/top_k values are invalid.
    """
    # Entity IDs must arrive through `filters`, never as top-level kwargs.
    _reject_top_level_entity_params(kwargs, "search")
    _validate_search_params(threshold=threshold, top_k=top_k)

    entity_keys = ("user_id", "agent_id", "run_id")
    logical_keys = ("AND", "OR", "NOT")

    # Work on a copy so the caller's dict is never mutated.
    effective_filters = filters.copy() if filters else {}
    for entity_key in entity_keys:
        if entity_key in effective_filters:
            effective_filters[entity_key] = _validate_and_trim_entity_id(
                effective_filters[entity_key], entity_key
            )

    if all(entity_key not in effective_filters for entity_key in entity_keys):
        raise ValueError(
            "filters must contain at least one of: user_id, agent_id, run_id. "
            "Example: filters={'user_id': 'u1'}"
        )

    if self._has_advanced_operators(effective_filters):
        processed_filters = self._process_metadata_filters(effective_filters)
        # Drop the raw logical keys and raw operator dicts; their translated
        # forms come back in through `processed_filters`.
        superseded = [
            name
            for name, spec in effective_filters.items()
            if name in logical_keys or (name not in entity_keys and isinstance(spec, dict))
        ]
        for name in superseded:
            effective_filters.pop(name, None)
        effective_filters.update(processed_filters)

    keys, encoded_ids = process_telemetry_filters(effective_filters)
    telemetry = {
        "limit": top_k,
        "version": self.api_version,
        "keys": keys,
        "encoded_ids": encoded_ids,
        "sync_type": "async",
        "threshold": threshold,
        "advanced_filters": bool(filters and self._has_advanced_operators(filters)),
    }
    capture_event("durag.search", self, telemetry)

    results = await self._search_vector_store(query, effective_filters, top_k, threshold)

    if rerank and self.reranker and results:
        try:
            # The reranker is synchronous; run it off the event loop.
            results = await asyncio.to_thread(self.reranker.rerank, query, results, top_k)
        except Exception as e:
            logger.warning(f"Reranking failed, using original results: {e}")

    return {"results": results}
def _process_metadata_filters(self, metadata_filters: Dict[str, Any]) -> Dict[str, Any]:
"""
Process enhanced metadata filters and convert them to vector store compatible format.
Args:
metadata_filters: Enhanced metadata filters with operators
Returns:
Dict of processed filters compatible with vector store
"""
processed_filters = {}
def process_condition(key: str, condition: Any) -> Dict[str, Any]:
if not isinstance(condition, dict):
# Simple equality: {"key": "value"}
if condition == "*":
# Wildcard: match everything for this field (implementation depends on vector store)
return {key: "*"}
return {key: condition}
result = {}
for operator, value in condition.items():
# Map platform operators to universal format that can be translated by each vector store
operator_map = {
"eq": "eq", "ne": "ne", "gt": "gt", "gte": "gte",
"lt": "lt", "lte": "lte", "in": "in", "nin": "nin",
"contains": "contains", "icontains": "icontains"
}
if operator in operator_map:
result.setdefault(key, {})[operator_map[operator]] = value
else:
raise ValueError(f"Unsupported metadata filter operator: {operator}")
return result
def merge_filters(target: Dict[str, Any], source: Dict[str, Any]) -> None:
"""Merge source into target, deep-merging nested operator dicts for the same key."""
for key, value in source.items():
if key in target and isinstance(target[key], dict) and isinstance(value, dict):
target[key].update(value)
else:
target[key] = value
for key, value in metadata_filters.items():
if key == "AND":
# Logical AND: combine multiple conditions
if not isinstance(value, list):
raise ValueError("AND operator requires a list of conditions")
for condition in value:
for sub_key, sub_value in condition.items():
merge_filters(processed_filters, process_condition(sub_key, sub_value))
elif key == "OR":
# Logical OR: Pass through to vector store for implementation-specific handling
if not isinstance(value, list) or not value:
raise ValueError("OR operator requires a non-empty list of conditions")
# Store OR conditions in a way that vector stores can interpret
processed_filters["$or"] = []
for condition in value:
or_condition = {}
for sub_key, sub_value in condition.items():
merge_filters(or_condition, process_condition(sub_key, sub_value))
processed_filters["$or"].append(or_condition)
elif key == "NOT":
# Logical NOT: Pass through to vector store for implementation-specific handling
if not isinstance(value, list) or not value:
raise ValueError("NOT operator requires a non-empty list of conditions")
processed_filters["$not"] = []
for condition in value:
not_condition = {}
for sub_key, sub_value in condition.items():
merge_filters(not_condition, process_condition(sub_key, sub_value))
processed_filters["$not"].append(not_condition)
else:
merge_filters(processed_filters, process_condition(key, value))
return processed_filters
def _has_advanced_operators(self, filters: Dict[str, Any]) -> bool:
"""
Check if filters contain advanced operators that need special processing.
Args:
filters: Dictionary of filters to check
Returns:
bool: True if advanced operators are detected
"""
if not isinstance(filters, dict):
return False
for key, value in filters.items():
# Check for platform-style logical operators
if key in ["AND", "OR", "NOT"]:
return True
# Check for comparison operators (without $ prefix for universal compatibility)
if isinstance(value, dict):
for op in value.keys():
if op in ["eq", "ne", "gt", "gte", "lt", "lte", "in", "nin", "contains", "icontains"]:
return True
# Check for wildcard values
if value == "*":
return True
return False
async def _search_vector_store(self, query, filters, limit, threshold=0.1):
    """Run the hybrid search pipeline and return formatted, scored memories.

    The query is lemmatized, scanned for entities and embedded; semantic and
    keyword searches are over-fetched; normalized BM25 scores and entity
    boosts are then combined with the semantic candidates by
    ``score_and_rank``.

    Args:
        query: Search text.
        filters: Filters forwarded to the vector store searches.
        limit: Maximum number of results to return.
        threshold: Minimum score; None is treated as 0.1.

    Returns:
        list: Formatted memory dicts including "score". Results whose payload
        has no "data" are dropped.
    """
    min_score = 0.1 if threshold is None else threshold

    # Step 1: Preprocess query (CPU-bound, so run off the event loop)
    lemmatized_query = await asyncio.to_thread(lemmatize_for_bm25, query)
    query_entities = await asyncio.to_thread(extract_entities, query)

    # Step 2: Embed query
    query_vector = await asyncio.to_thread(self.embedding_model.embed, query, "search")

    # Step 3: Semantic search, over-fetching so re-scoring has room to reorder
    fetch_size = max(limit * 4, 60)
    semantic_hits = await asyncio.to_thread(
        self.vector_store.search, query=query, vectors=query_vector, top_k=fetch_size, filters=filters
    )

    # Step 4: Keyword search; a None result is treated as "no keyword results"
    keyword_hits = await asyncio.to_thread(
        self.vector_store.keyword_search, query=lemmatized_query, top_k=fetch_size, filters=filters
    )

    # Step 5: Normalize the positive raw keyword scores
    bm25_scores = {}
    if keyword_hits is not None:
        midpoint, steepness = get_bm25_params(query, lemmatized=lemmatized_query)
        for hit in keyword_hits:
            # Hits may be objects with attributes or plain dicts.
            hit_id = str(hit.id) if hasattr(hit, 'id') else str(hit.get('id', ''))
            raw_score = hit.score if hasattr(hit, 'score') else hit.get('score', 0)
            if not raw_score or not (raw_score > 0):
                continue
            bm25_scores[hit_id] = normalize_bm25(raw_score, midpoint, steepness)

    # Step 6: Entity boosts (only when the query mentions entities)
    entity_boosts = {}
    if query_entities:
        entity_boosts = await self._compute_entity_boosts_async(query_entities, filters)

    # Step 7: Candidate set comes from the semantic hits only
    candidates = [
        {
            "id": str(hit.id),
            "score": hit.score,
            "payload": hit.payload if hasattr(hit, 'payload') else {},
        }
        for hit in semantic_hits
    ]

    # Step 8: Score and rank
    ranked = score_and_rank(
        semantic_results=candidates,
        bm25_scores=bm25_scores,
        entity_boosts=entity_boosts,
        threshold=min_score,
        top_k=limit,
    )

    # Step 9: Format results
    promoted = ("user_id", "agent_id", "run_id", "actor_id", "role")
    reserved = {"data", "hash", "created_at", "updated_at", "id", "text_lemmatized", "attributed_to", *promoted}

    formatted = []
    for scored in ranked:
        payload = scored.get("payload") or {}
        if not payload.get("data"):
            continue

        entry = MemoryItem(
            id=scored["id"],
            memory=payload.get("data", ""),
            hash=payload.get("hash"),
            created_at=payload.get("created_at"),
            updated_at=payload.get("updated_at"),
            score=scored["score"],
        ).model_dump()

        entry.update({name: payload[name] for name in promoted if name in payload})

        extras = {name: value for name, value in payload.items() if name not in reserved}
        if extras:
            if not entry.get("metadata"):
                entry["metadata"] = {}
            entry["metadata"].update(extras)

        formatted.append(entry)

    return formatted
async def _compute_entity_boosts_async(self, query_entities, filters):
    """Async version of entity boost computation.

    For each of the first eight query entities (de-duplicated on stripped,
    lower-cased text) the entity store is searched, and every memory linked
    to a sufficiently similar entity receives a boost. A memory keeps the
    largest boost it gets from any entity.

    A match whose payload or score is None is now skipped. Previously it
    raised inside the outer ``try`` and discarded the boosts of all remaining
    entities. This mirrors the ``match.payload or {}`` handling used when
    entities are linked.

    Args:
        query_entities: Sequence of ``(entity_type, entity_text)`` pairs.
        filters: Search filters; only truthy user_id / agent_id / run_id
            values are forwarded to the entity store.

    Returns:
        dict: Mapping of memory ID (as str) to boost value. Errors are logged
        and whatever was computed so far is returned.
    """
    seen = set()
    deduped = []
    for entity_type, entity_text in query_entities[:8]:
        key = entity_text.strip().lower()
        if key and key not in seen:
            seen.add(key)
            deduped.append((entity_type, entity_text))

    if not deduped:
        return {}

    search_filters = {k: v for k, v in filters.items() if k in ("user_id", "agent_id", "run_id") and v}
    memory_boosts = {}

    try:
        for _, entity_text in deduped:
            entity_embedding = await asyncio.to_thread(self.embedding_model.embed, entity_text, "search")
            matches = await asyncio.to_thread(
                self.entity_store.search,
                query=entity_text,
                vectors=entity_embedding,
                top_k=500,
                filters=search_filters,
            )

            for match in matches:
                # Missing or None scores count as 0.0 and fall below the cutoff.
                similarity = getattr(match, "score", None) or 0.0
                if similarity < 0.5:
                    continue

                payload = getattr(match, "payload", None) or {}
                linked_memory_ids = payload.get("linked_memory_ids", [])
                if not isinstance(linked_memory_ids, list):
                    continue

                # Entities linked to many memories are less specific, so the
                # boost shrinks as the number of linked memories grows.
                num_linked = max(len(linked_memory_ids), 1)
                memory_count_weight = 1.0 / (1.0 + 0.001 * ((num_linked - 1) ** 2))
                boost = similarity * ENTITY_BOOST_WEIGHT * memory_count_weight

                for memory_id in linked_memory_ids:
                    if memory_id:
                        memory_key = str(memory_id)
                        memory_boosts[memory_key] = max(memory_boosts.get(memory_key, 0.0), boost)
    except Exception as e:
        logger.warning(f"Entity boost computation failed: {e}")

    return memory_boosts
async def update(self, memory_id, data, metadata: Optional[Dict[str, Any]] = None):
    """
    Replace the content of an existing memory asynchronously.

    The new text is embedded in a worker thread and handed, together with the
    optional metadata, to ``self._update_memory``.

    Args:
        memory_id (str): ID of the memory to update.
        data (str): New content to update the memory with.
        metadata (dict, optional): Metadata to update with the memory. Defaults to None.

    Returns:
        dict: Success message indicating the memory was updated.

    Example:
        >>> await m.update(memory_id="mem_123", data="Likes to play tennis on weekends")
        {'message': 'Memory updated successfully!'}
    """
    capture_event("durag.update", self, {"memory_id": memory_id, "sync_type": "async"})

    new_vector = await asyncio.to_thread(self.embedding_model.embed, data, "update")
    # Pre-computed embedding keyed by text, passed along to _update_memory.
    await self._update_memory(memory_id, data, {data: new_vector}, metadata)

    return {"message": "Memory updated successfully!"}
async def delete(self, memory_id):
    """
    Remove a single memory by its ID asynchronously.

    Args:
        memory_id (str): ID of the memory to delete.

    Returns:
        dict: Success message indicating the memory was deleted.

    Raises:
        ValueError: If the vector store has no memory with this ID.
    """
    capture_event("durag.delete", self, {"memory_id": memory_id, "sync_type": "async"})

    existing_memory = await asyncio.to_thread(self.vector_store.get, vector_id=memory_id)
    if existing_memory is not None:
        await self._delete_memory(memory_id, existing_memory)
        return {"message": "Memory deleted successfully!"}

    raise ValueError(f"Memory with id {memory_id} not found")
async def delete_all(self, user_id=None, agent_id=None, run_id=None):
    """
    Delete all memories in the given scope asynchronously.

    The listing result is unwrapped the same way as in
    ``_get_all_from_vector_store``: it may be a flat sequence of records or
    records wrapped in an outer container. Previously ``memories[0]`` was
    indexed unconditionally, which raised ``IndexError`` on an empty list and
    mis-iterated a flat list.

    Args:
        user_id (str, optional): ID of the user to delete memories for. Defaults to None.
        agent_id (str, optional): ID of the agent to delete memories for. Defaults to None.
        run_id (str, optional): ID of the run to delete memories for. Defaults to None.

    Returns:
        dict: Success message indicating the memories were deleted.

    Raises:
        ValueError: If none of user_id, agent_id, run_id is provided.
    """
    filters = {}
    if user_id:
        filters["user_id"] = user_id
    if agent_id:
        filters["agent_id"] = agent_id
    if run_id:
        filters["run_id"] = run_id

    if not filters:
        raise ValueError(
            "At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method."
        )

    keys, encoded_ids = process_telemetry_filters(filters)
    capture_event("durag.delete_all", self, {"keys": keys, "encoded_ids": encoded_ids, "sync_type": "async"})

    listed = await asyncio.to_thread(self.vector_store.list, filters=filters)

    # Unwrap one level when the first element is itself a container.
    if isinstance(listed, (tuple, list)) and listed and isinstance(listed[0], (list, tuple)):
        memories = list(listed[0])
    else:
        memories = list(listed or [])

    # Deletions are independent, so run them concurrently.
    await asyncio.gather(*(self._delete_memory(memory.id) for memory in memories))

    logger.info(f"Deleted {len(memories)} memories")
    return {"message": "Memories deleted successfully!"}
async def history(self, memory_id):
    """
    Fetch the recorded change history of a memory asynchronously.

    Args:
        memory_id (str): ID of the memory to get history for.

    Returns:
        list: List of changes for the memory, as returned by ``self.db.get_history``.
    """
    capture_event("durag.history", self, {"memory_id": memory_id, "sync_type": "async"})
    # The history lookup is synchronous, so it runs in a worker thread.
    change_log = await asyncio.to_thread(self.db.get_history, memory_id)
    return change_log
async def _create_memory(self, data, existing_embeddings, metadata=None):
"""
Embed ``data``, insert it into the vector store under a fresh UUID, and record an "ADD" history row.
Args:
data (str): Text of the memory to store.
existing_embeddings (dict): Cache mapping text to an already-computed embedding; consulted before calling the embedder.
metadata (dict, optional): Extra payload fields. It is deep-copied, so the caller's dict is not mutated. Defaults to None.
Returns:
str: The newly generated memory ID.
"""
logger.debug(f"Creating memory with {data=}")
# Reuse the cached embedding for this exact text when the caller already computed one.
if data in existing_embeddings:
embeddings = existing_embeddings[data]
else:
embeddings = await asyncio.to_thread(self.embedding_model.embed, data, memory_action="add")
memory_id = str(uuid.uuid4())
# Work on a copy so the fields added below never leak into the caller's metadata.
new_metadata = deepcopy(metadata) if metadata is not None else {}
new_metadata["data"] = data
# MD5 hex digest of the text, stored in the payload under "hash".
new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
# A caller-supplied "created_at" is kept; otherwise the current UTC time is used.
# NOTE(review): indentation is lost in this view, so it is unclear whether the
# "updated_at" assignment runs unconditionally or only when "created_at" was
# defaulted here — confirm against the properly indented file before refactoring.
if "created_at" not in new_metadata:
new_metadata["created_at"] = datetime.now(timezone.utc).isoformat()
new_metadata["updated_at"] = new_metadata["created_at"]
# Lemmatized copy of the text; presumably consumed by BM25 keyword search (per the helper's name) — confirm.
new_metadata["text_lemmatized"] = lemmatize_for_bm25(data)
# Vector-store and history writes are blocking calls, so each runs in a worker thread.
await asyncio.to_thread(
self.vector_store.insert,
vectors=[embeddings],
ids=[memory_id],
payloads=[new_metadata],
)
# History row: no previous value (None), new value is ``data``, event "ADD".
await asyncio.to_thread(
self.db.add_history,
memory_id,
None,
data,
"ADD",
created_at=new_metadata.get("created_at"),
updated_at=new_metadata.get("updated_at"),
actor_id=new_metadata.get("actor_id"),
role=new_metadata.get("role"),
)
return memory_id
async def _create_procedural_memory(self, messages, metadata=None, llm=None, prompt=None):
    """
    Create a procedural memory asynchronously.

    The conversation is wrapped in a system prompt and a closing user
    instruction, summarised by an LLM, embedded, and stored as a memory whose
    ``memory_type`` is ``MemoryType.PROCEDURAL``.

    Args:
        messages (list): List of messages to create a procedural memory from.
        metadata (dict): Metadata to store with the procedural memory. Required;
            the default of None exists only for signature compatibility.
        llm (llm, optional): LLM exposing ``invoke(input=...)`` to use instead of
            the instance's own LLM. Defaults to None.
        prompt (str, optional): System prompt override. Defaults to
            ``PROCEDURAL_MEMORY_SYSTEM_PROMPT``.

    Returns:
        dict: ``{"results": [{"id": <memory_id>, "memory": <text>, "event": "ADD"}]}``.

    Raises:
        ValueError: If ``metadata`` is None.
        Exception: Import failures for ``langchain-core`` and any error raised
            while generating the summary are logged and re-raised unchanged.
    """
    try:
        from langchain_core.messages.utils import (
            convert_to_messages,  # type: ignore
        )
    except Exception:
        logger.error(
            "Import error while loading langchain-core. Please install 'langchain-core' to use procedural memory."
        )
        raise

    # Validate before the LLM and embedding calls so an invalid request fails
    # immediately instead of after paying for a full generation.
    if metadata is None:
        raise ValueError("Metadata cannot be None for procedural memory.")

    logger.info("Creating procedural memory")
    parsed_messages = [
        {"role": "system", "content": prompt or PROCEDURAL_MEMORY_SYSTEM_PROMPT},
        *messages,
        {"role": "user", "content": "Create procedural memory of the above conversation."},
    ]

    try:
        if llm is not None:
            # Caller-supplied LLM: convert the dict messages and read ``.content``.
            parsed_messages = convert_to_messages(parsed_messages)
            response = await asyncio.to_thread(llm.invoke, input=parsed_messages)
            procedural_memory = response.content
        else:
            # Built-in LLM: returns text, which is cleaned by remove_code_blocks.
            procedural_memory = await asyncio.to_thread(self.llm.generate_response, messages=parsed_messages)
            procedural_memory = remove_code_blocks(procedural_memory)
    except Exception as e:
        logger.error(f"Error generating procedural memory summary: {e}")
        raise

    # Build a new dict so the caller's metadata is left untouched.
    metadata = {**metadata, "memory_type": MemoryType.PROCEDURAL.value}

    # Embed once here and pass the vector through the cache argument so
    # _create_memory does not embed the same text again.
    embeddings = await asyncio.to_thread(self.embedding_model.embed, procedural_memory, memory_action="add")
    memory_id = await self._create_memory(procedural_memory, {procedural_memory: embeddings}, metadata=metadata)
    capture_event("durag._create_procedural_memory", self, {"memory_id": memory_id, "sync_type": "async"})

    return {"results": [{"id": memory_id, "memory": procedural_memory, "event": "ADD"}]}
async def _update_memory(self, memory_id, data, existing_embeddings, metadata=None):
    """
    Replace the text of an existing memory and record an "UPDATE" history row.

    Args:
        memory_id (str): ID of the memory to update.
        data (str): New text for the memory.
        existing_embeddings (dict): Cache mapping text to an already-computed
            embedding; consulted before calling the embedder.
        metadata (dict, optional): New payload fields. It is deep-copied, so the
            caller's dict is not mutated. Defaults to None.

    Returns:
        str: ``memory_id``, unchanged.

    Raises:
        ValueError: If the vector-store lookup fails (the original error is
            attached as ``__cause__``) or no memory exists for ``memory_id``.
    """
    logger.info(f"Updating memory with {data=}")
    try:
        existing_memory = await asyncio.to_thread(self.vector_store.get, vector_id=memory_id)
    except Exception as exc:
        logger.error(f"Error getting memory with ID {memory_id} during update.")
        # Chain explicitly so the traceback shows the store failure as the
        # direct cause rather than an unrelated error during handling.
        raise ValueError(
            f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'"
        ) from exc
    if existing_memory is None:
        raise ValueError(f"Memory with id {memory_id} not found. Please provide a valid 'memory_id'")

    old_payload = existing_memory.payload
    prev_value = old_payload.get("data")

    new_metadata = deepcopy(metadata) if metadata is not None else {}
    new_metadata["data"] = data
    new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
    new_metadata["text_lemmatized"] = lemmatize_for_bm25(data)
    # The creation time always comes from the stored record; only updated_at moves.
    new_metadata["created_at"] = old_payload.get("created_at")
    new_metadata["updated_at"] = datetime.now(timezone.utc).isoformat()

    # Preserve session identifiers from existing memory only if not provided in new metadata
    for session_key in ("user_id", "agent_id", "run_id"):
        if session_key not in new_metadata and session_key in old_payload:
            new_metadata[session_key] = old_payload[session_key]
    # actor_id is always taken from the stored record when present, even if the
    # new metadata supplies one.
    if "actor_id" in old_payload:
        new_metadata["actor_id"] = old_payload["actor_id"]
    if "role" not in new_metadata and "role" in old_payload:
        new_metadata["role"] = old_payload["role"]

    # Reuse the cached embedding for this exact text when available.
    if data in existing_embeddings:
        embeddings = existing_embeddings[data]
    else:
        embeddings = await asyncio.to_thread(self.embedding_model.embed, data, "update")

    await asyncio.to_thread(
        self.vector_store.update,
        vector_id=memory_id,
        vector=embeddings,
        payload=new_metadata,
    )
    logger.info(f"Updating memory with ID {memory_id=} with {data=}")

    await asyncio.to_thread(
        self.db.add_history,
        memory_id,
        prev_value,
        data,
        "UPDATE",
        created_at=new_metadata["created_at"],
        updated_at=new_metadata["updated_at"],
        actor_id=new_metadata.get("actor_id"),
        role=new_metadata.get("role"),
    )

    # Entity-store cleanup: strip this memory's id from old-text entities,
    # then re-extract entities from the new text and link them back.
    session_filters = {k: new_metadata[k] for k in ("user_id", "agent_id", "run_id") if new_metadata.get(k)}
    await self._remove_memory_from_entity_store(memory_id, session_filters)
    await self._link_entities_for_memory(memory_id, data, session_filters)
    return memory_id
async def _delete_memory(self, memory_id, existing_memory=None):
    """
    Delete one memory from the vector store and record a "DELETE" history row.

    Args:
        memory_id (str): ID of the memory to delete.
        existing_memory (optional): The already-fetched vector-store record for
            ``memory_id``. When None, the record is fetched here.

    Returns:
        str: ``memory_id``, unchanged.

    Raises:
        ValueError: If no record exists for ``memory_id``.
    """
    logger.info(f"Deleting memory with {memory_id=}")
    if existing_memory is None:
        existing_memory = await asyncio.to_thread(self.vector_store.get, vector_id=memory_id)
        if existing_memory is None:
            raise ValueError(f"Memory with id {memory_id} not found. Please provide a valid 'memory_id'")

    # Normalise a missing payload to {} once, up front, and read every field
    # through it; otherwise a record with payload=None fails with
    # AttributeError before the guard is ever applied.
    payload = existing_memory.payload or {}
    prev_value = payload.get("data", "")
    created_at = _normalize_iso_timestamp_to_utc(payload.get("created_at"))
    updated_at = datetime.now(timezone.utc).isoformat()
    session_filters = {k: payload[k] for k in ("user_id", "agent_id", "run_id") if payload.get(k)}

    await asyncio.to_thread(self.vector_store.delete, vector_id=memory_id)
    # History row: previous value is the stored text, new value is None.
    await asyncio.to_thread(
        self.db.add_history,
        memory_id,
        prev_value,
        None,
        "DELETE",
        created_at=created_at,
        updated_at=updated_at,
        actor_id=payload.get("actor_id"),
        role=payload.get("role"),
        is_deleted=1,
    )

    # Entity-store cleanup: strip this memory's id from any entity records
    # that linked to it. Non-fatal — the helper swallows errors.
    await self._remove_memory_from_entity_store(memory_id, session_filters)
    return memory_id
async def reset(self):
    """
    Wipe the memory store asynchronously and rebuild empty backends.

    Steps, in order:
        1. Delete the vector store collection and close its client, if it has one.
        2. Drop the ``history`` table and close the history database connection.
        3. Create a new ``SQLiteManager`` and a new vector store from the config.
    """
    logger.warning("Resetting all memories")

    store = self.vector_store
    await asyncio.to_thread(store.delete_col)
    gc.collect()
    # Not every vector store exposes a closable client, so probe before closing.
    if hasattr(store, "client") and hasattr(store.client, "close"):
        await asyncio.to_thread(store.client.close)

    if hasattr(self.db, "connection") and self.db.connection:
        connection = self.db.connection
        await asyncio.to_thread(connection.execute, "DROP TABLE IF EXISTS history")
        await asyncio.to_thread(connection.close)

    # Rebuild both backends from the stored configuration.
    self.db = SQLiteManager(self.config.history_db_path)
    store_settings = self.config.vector_store
    self.vector_store = VectorStoreFactory.create(store_settings.provider, store_settings.config)

    capture_event("durag.reset", self, {"sync_type": "async"})
def close(self):
    """
    Close the history database handle held by this AsyncMemory instance.

    Does nothing when the instance has no ``db`` attribute or it is already
    None; after a successful close the attribute is set to None, so repeated
    calls are harmless.
    """
    history_db = getattr(self, "db", None)
    if history_db is None:
        return
    history_db.close()
    self.db = None
async def chat(self, query):
    """
    Placeholder for a chat interface; not available yet.

    Args:
        query: The chat query (currently unused).

    Raises:
        NotImplementedError: Always.
    """
    message = "Chat function not implemented yet."
    raise NotImplementedError(message)