Ignore external conversation_id for device/user memory

This commit is contained in:
Your Name 2025-12-20 13:51:41 -06:00
parent b964c74a9e
commit 1d97288174

View file

@ -29,6 +29,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers import device_registry as dr, intent, llm, template
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import ulid
@ -44,6 +45,10 @@ from .const import (
CONF_PRESENCE_PENALTY,
CONF_PROMPT,
CONF_RESPONSE_FORMAT,
CONF_SEARXNG_ENABLED,
CONF_SEARXNG_LANGUAGE,
CONF_SEARXNG_SAFESEARCH,
CONF_SEARXNG_URL,
CONF_SEED,
CONF_STOP,
CONF_TEMPERATURE,
@ -56,6 +61,10 @@ from .const import (
DEFAULT_PARALLEL_TOOL_CALLS,
DEFAULT_PRESENCE_PENALTY,
DEFAULT_RESPONSE_FORMAT,
DEFAULT_SEARXNG_ENABLED,
DEFAULT_SEARXNG_LANGUAGE,
DEFAULT_SEARXNG_SAFESEARCH,
DEFAULT_SEARXNG_URL,
DEFAULT_TEMPERATURE,
DEFAULT_TOOL_CHOICE,
DEFAULT_TOP_P,
@ -114,6 +123,57 @@ def _parse_stop_sequences(value: str | None) -> list[str] | None:
return [part for part in expanded if part]
def _searxng_tool() -> ChatCompletionToolParam:
tool_spec = FunctionDefinition(
name="web_search",
parameters={
"type": "object",
"properties": {
"query": {"type": "string"},
"limit": {"type": "integer", "default": 5},
"language": {"type": "string"},
"safesearch": {"type": "integer"},
},
"required": ["query"],
},
description="Search the web via searxng and return top results.",
)
return ChatCompletionToolParam(type="function", function=tool_spec)
async def _run_searxng(
hass: HomeAssistant,
options: dict[str, Any],
tool_args: dict[str, Any],
) -> dict[str, Any]:
base_url = options.get(CONF_SEARXNG_URL, DEFAULT_SEARXNG_URL).rstrip("/")
params = {
"q": tool_args.get("query", ""),
"format": "json",
"language": tool_args.get("language") or options.get(CONF_SEARXNG_LANGUAGE, DEFAULT_SEARXNG_LANGUAGE),
"safesearch": tool_args.get("safesearch", options.get(CONF_SEARXNG_SAFESEARCH, DEFAULT_SEARXNG_SAFESEARCH)),
}
limit = tool_args.get("limit", 5)
try:
limit = int(limit)
except (TypeError, ValueError):
limit = 5
session = async_get_clientsession(hass)
async with session.get(f"{base_url}/search", params=params, timeout=20) as resp:
resp.raise_for_status()
payload = await resp.json()
results = []
for item in payload.get("results", [])[:limit]:
results.append(
{
"title": item.get("title"),
"url": item.get("url"),
"content": item.get("content"),
}
)
return {"results": results}
class GroqdConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent
):
@ -187,6 +247,10 @@ class GroqdConversationEntity(
response=intent_response, conversation_id=user_input.conversation_id
)
tools = [_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools]
if options.get(CONF_SEARXNG_ENABLED, DEFAULT_SEARXNG_ENABLED):
if tools is None:
tools = []
tools.append(_searxng_tool())
memory_scope = options.get(CONF_MEMORY_SCOPE, DEFAULT_MEMORY_SCOPE)
memory_key = None
@ -209,6 +273,7 @@ class GroqdConversationEntity(
elif memory_scope == "global":
memory_key = "global"
if memory_scope == "conversation":
if user_input.conversation_id is not None:
conversation_id = user_input.conversation_id
history = self.history.get(conversation_id, [])
@ -218,6 +283,13 @@ class GroqdConversationEntity(
else:
conversation_id = ulid.ulid_now()
history = []
else:
if memory_key and memory_key in self._memory_index:
conversation_id = self._memory_index[memory_key]
history = self.history.get(conversation_id, [])
else:
conversation_id = ulid.ulid_now()
history = []
if (
user_input.context
@ -345,13 +417,25 @@ class GroqdConversationEntity(
messages.append(message_convert(response))
tool_calls = response.tool_calls
if not tool_calls or not llm_api:
if not tool_calls:
break
for tool_call in tool_calls:
tool_name = tool_call.function.name
try:
tool_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError:
tool_args = {}
if tool_name == "web_search":
try:
tool_response = await _run_searxng(self.hass, options, tool_args)
except Exception as err:
tool_response = {"error": type(err).__name__, "error_text": str(err)}
elif llm_api:
tool_input = llm.ToolInput(
tool_name=tool_call.function.name,
tool_args=json.loads(tool_call.function.arguments),
tool_name=tool_name,
tool_args=tool_args,
)
try:
tool_response = await llm_api.async_call_tool(tool_input)
@ -359,6 +443,8 @@ class GroqdConversationEntity(
tool_response = {"error": type(err).__name__}
if str(err):
tool_response["error_text"] = str(err)
else:
tool_response = {"error": "ToolNotAvailable", "error_text": tool_name}
messages.append(
ChatCompletionToolMessageParam(