mirror of
https://github.com/sudoxreboot/groqd
synced 2026-04-14 11:36:49 +00:00
Ignore external conversation_id for device/user memory
This commit is contained in:
parent
b964c74a9e
commit
1d97288174
1 changed files with 104 additions and 18 deletions
|
|
@ -29,6 +29,7 @@ from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
||||||
|
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||||
from homeassistant.helpers import device_registry as dr, intent, llm, template
|
from homeassistant.helpers import device_registry as dr, intent, llm, template
|
||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from homeassistant.util import ulid
|
from homeassistant.util import ulid
|
||||||
|
|
@ -44,6 +45,10 @@ from .const import (
|
||||||
CONF_PRESENCE_PENALTY,
|
CONF_PRESENCE_PENALTY,
|
||||||
CONF_PROMPT,
|
CONF_PROMPT,
|
||||||
CONF_RESPONSE_FORMAT,
|
CONF_RESPONSE_FORMAT,
|
||||||
|
CONF_SEARXNG_ENABLED,
|
||||||
|
CONF_SEARXNG_LANGUAGE,
|
||||||
|
CONF_SEARXNG_SAFESEARCH,
|
||||||
|
CONF_SEARXNG_URL,
|
||||||
CONF_SEED,
|
CONF_SEED,
|
||||||
CONF_STOP,
|
CONF_STOP,
|
||||||
CONF_TEMPERATURE,
|
CONF_TEMPERATURE,
|
||||||
|
|
@ -56,6 +61,10 @@ from .const import (
|
||||||
DEFAULT_PARALLEL_TOOL_CALLS,
|
DEFAULT_PARALLEL_TOOL_CALLS,
|
||||||
DEFAULT_PRESENCE_PENALTY,
|
DEFAULT_PRESENCE_PENALTY,
|
||||||
DEFAULT_RESPONSE_FORMAT,
|
DEFAULT_RESPONSE_FORMAT,
|
||||||
|
DEFAULT_SEARXNG_ENABLED,
|
||||||
|
DEFAULT_SEARXNG_LANGUAGE,
|
||||||
|
DEFAULT_SEARXNG_SAFESEARCH,
|
||||||
|
DEFAULT_SEARXNG_URL,
|
||||||
DEFAULT_TEMPERATURE,
|
DEFAULT_TEMPERATURE,
|
||||||
DEFAULT_TOOL_CHOICE,
|
DEFAULT_TOOL_CHOICE,
|
||||||
DEFAULT_TOP_P,
|
DEFAULT_TOP_P,
|
||||||
|
|
@ -114,6 +123,57 @@ def _parse_stop_sequences(value: str | None) -> list[str] | None:
|
||||||
return [part for part in expanded if part]
|
return [part for part in expanded if part]
|
||||||
|
|
||||||
|
|
||||||
|
def _searxng_tool() -> ChatCompletionToolParam:
|
||||||
|
tool_spec = FunctionDefinition(
|
||||||
|
name="web_search",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string"},
|
||||||
|
"limit": {"type": "integer", "default": 5},
|
||||||
|
"language": {"type": "string"},
|
||||||
|
"safesearch": {"type": "integer"},
|
||||||
|
},
|
||||||
|
"required": ["query"],
|
||||||
|
},
|
||||||
|
description="Search the web via searxng and return top results.",
|
||||||
|
)
|
||||||
|
return ChatCompletionToolParam(type="function", function=tool_spec)
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_searxng(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
options: dict[str, Any],
|
||||||
|
tool_args: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
base_url = options.get(CONF_SEARXNG_URL, DEFAULT_SEARXNG_URL).rstrip("/")
|
||||||
|
params = {
|
||||||
|
"q": tool_args.get("query", ""),
|
||||||
|
"format": "json",
|
||||||
|
"language": tool_args.get("language") or options.get(CONF_SEARXNG_LANGUAGE, DEFAULT_SEARXNG_LANGUAGE),
|
||||||
|
"safesearch": tool_args.get("safesearch", options.get(CONF_SEARXNG_SAFESEARCH, DEFAULT_SEARXNG_SAFESEARCH)),
|
||||||
|
}
|
||||||
|
limit = tool_args.get("limit", 5)
|
||||||
|
try:
|
||||||
|
limit = int(limit)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
limit = 5
|
||||||
|
session = async_get_clientsession(hass)
|
||||||
|
async with session.get(f"{base_url}/search", params=params, timeout=20) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
payload = await resp.json()
|
||||||
|
results = []
|
||||||
|
for item in payload.get("results", [])[:limit]:
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"title": item.get("title"),
|
||||||
|
"url": item.get("url"),
|
||||||
|
"content": item.get("content"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return {"results": results}
|
||||||
|
|
||||||
|
|
||||||
class GroqdConversationEntity(
|
class GroqdConversationEntity(
|
||||||
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
||||||
):
|
):
|
||||||
|
|
@ -187,6 +247,10 @@ class GroqdConversationEntity(
|
||||||
response=intent_response, conversation_id=user_input.conversation_id
|
response=intent_response, conversation_id=user_input.conversation_id
|
||||||
)
|
)
|
||||||
tools = [_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools]
|
tools = [_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools]
|
||||||
|
if options.get(CONF_SEARXNG_ENABLED, DEFAULT_SEARXNG_ENABLED):
|
||||||
|
if tools is None:
|
||||||
|
tools = []
|
||||||
|
tools.append(_searxng_tool())
|
||||||
|
|
||||||
memory_scope = options.get(CONF_MEMORY_SCOPE, DEFAULT_MEMORY_SCOPE)
|
memory_scope = options.get(CONF_MEMORY_SCOPE, DEFAULT_MEMORY_SCOPE)
|
||||||
memory_key = None
|
memory_key = None
|
||||||
|
|
@ -209,15 +273,23 @@ class GroqdConversationEntity(
|
||||||
elif memory_scope == "global":
|
elif memory_scope == "global":
|
||||||
memory_key = "global"
|
memory_key = "global"
|
||||||
|
|
||||||
if user_input.conversation_id is not None:
|
if memory_scope == "conversation":
|
||||||
conversation_id = user_input.conversation_id
|
if user_input.conversation_id is not None:
|
||||||
history = self.history.get(conversation_id, [])
|
conversation_id = user_input.conversation_id
|
||||||
elif memory_key and memory_key in self._memory_index:
|
history = self.history.get(conversation_id, [])
|
||||||
conversation_id = self._memory_index[memory_key]
|
elif memory_key and memory_key in self._memory_index:
|
||||||
history = self.history.get(conversation_id, [])
|
conversation_id = self._memory_index[memory_key]
|
||||||
|
history = self.history.get(conversation_id, [])
|
||||||
|
else:
|
||||||
|
conversation_id = ulid.ulid_now()
|
||||||
|
history = []
|
||||||
else:
|
else:
|
||||||
conversation_id = ulid.ulid_now()
|
if memory_key and memory_key in self._memory_index:
|
||||||
history = []
|
conversation_id = self._memory_index[memory_key]
|
||||||
|
history = self.history.get(conversation_id, [])
|
||||||
|
else:
|
||||||
|
conversation_id = ulid.ulid_now()
|
||||||
|
history = []
|
||||||
|
|
||||||
if (
|
if (
|
||||||
user_input.context
|
user_input.context
|
||||||
|
|
@ -345,20 +417,34 @@ class GroqdConversationEntity(
|
||||||
messages.append(message_convert(response))
|
messages.append(message_convert(response))
|
||||||
tool_calls = response.tool_calls
|
tool_calls = response.tool_calls
|
||||||
|
|
||||||
if not tool_calls or not llm_api:
|
if not tool_calls:
|
||||||
break
|
break
|
||||||
|
|
||||||
for tool_call in tool_calls:
|
for tool_call in tool_calls:
|
||||||
tool_input = llm.ToolInput(
|
tool_name = tool_call.function.name
|
||||||
tool_name=tool_call.function.name,
|
|
||||||
tool_args=json.loads(tool_call.function.arguments),
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
tool_response = await llm_api.async_call_tool(tool_input)
|
tool_args = json.loads(tool_call.function.arguments)
|
||||||
except (HomeAssistantError, vol.Invalid) as err:
|
except json.JSONDecodeError:
|
||||||
tool_response = {"error": type(err).__name__}
|
tool_args = {}
|
||||||
if str(err):
|
|
||||||
tool_response["error_text"] = str(err)
|
if tool_name == "web_search":
|
||||||
|
try:
|
||||||
|
tool_response = await _run_searxng(self.hass, options, tool_args)
|
||||||
|
except Exception as err:
|
||||||
|
tool_response = {"error": type(err).__name__, "error_text": str(err)}
|
||||||
|
elif llm_api:
|
||||||
|
tool_input = llm.ToolInput(
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_args=tool_args,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
tool_response = await llm_api.async_call_tool(tool_input)
|
||||||
|
except (HomeAssistantError, vol.Invalid) as err:
|
||||||
|
tool_response = {"error": type(err).__name__}
|
||||||
|
if str(err):
|
||||||
|
tool_response["error_text"] = str(err)
|
||||||
|
else:
|
||||||
|
tool_response = {"error": "ToolNotAvailable", "error_text": tool_name}
|
||||||
|
|
||||||
messages.append(
|
messages.append(
|
||||||
ChatCompletionToolMessageParam(
|
ChatCompletionToolMessageParam(
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue