mirror of
https://github.com/sudoxreboot/groqd
synced 2026-04-14 03:26:35 +00:00
Ignore external conversation_id for device/user memory
This commit is contained in:
parent
b964c74a9e
commit
1d97288174
1 changed files with 104 additions and 18 deletions
|
|
@ -29,6 +29,7 @@ from homeassistant.config_entries import ConfigEntry
|
|||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from homeassistant.helpers import device_registry as dr, intent, llm, template
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.util import ulid
|
||||
|
|
@ -44,6 +45,10 @@ from .const import (
|
|||
CONF_PRESENCE_PENALTY,
|
||||
CONF_PROMPT,
|
||||
CONF_RESPONSE_FORMAT,
|
||||
CONF_SEARXNG_ENABLED,
|
||||
CONF_SEARXNG_LANGUAGE,
|
||||
CONF_SEARXNG_SAFESEARCH,
|
||||
CONF_SEARXNG_URL,
|
||||
CONF_SEED,
|
||||
CONF_STOP,
|
||||
CONF_TEMPERATURE,
|
||||
|
|
@ -56,6 +61,10 @@ from .const import (
|
|||
DEFAULT_PARALLEL_TOOL_CALLS,
|
||||
DEFAULT_PRESENCE_PENALTY,
|
||||
DEFAULT_RESPONSE_FORMAT,
|
||||
DEFAULT_SEARXNG_ENABLED,
|
||||
DEFAULT_SEARXNG_LANGUAGE,
|
||||
DEFAULT_SEARXNG_SAFESEARCH,
|
||||
DEFAULT_SEARXNG_URL,
|
||||
DEFAULT_TEMPERATURE,
|
||||
DEFAULT_TOOL_CHOICE,
|
||||
DEFAULT_TOP_P,
|
||||
|
|
@ -114,6 +123,57 @@ def _parse_stop_sequences(value: str | None) -> list[str] | None:
|
|||
return [part for part in expanded if part]
|
||||
|
||||
|
||||
def _searxng_tool() -> ChatCompletionToolParam:
|
||||
tool_spec = FunctionDefinition(
|
||||
name="web_search",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"limit": {"type": "integer", "default": 5},
|
||||
"language": {"type": "string"},
|
||||
"safesearch": {"type": "integer"},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
description="Search the web via searxng and return top results.",
|
||||
)
|
||||
return ChatCompletionToolParam(type="function", function=tool_spec)
|
||||
|
||||
|
||||
async def _run_searxng(
|
||||
hass: HomeAssistant,
|
||||
options: dict[str, Any],
|
||||
tool_args: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
base_url = options.get(CONF_SEARXNG_URL, DEFAULT_SEARXNG_URL).rstrip("/")
|
||||
params = {
|
||||
"q": tool_args.get("query", ""),
|
||||
"format": "json",
|
||||
"language": tool_args.get("language") or options.get(CONF_SEARXNG_LANGUAGE, DEFAULT_SEARXNG_LANGUAGE),
|
||||
"safesearch": tool_args.get("safesearch", options.get(CONF_SEARXNG_SAFESEARCH, DEFAULT_SEARXNG_SAFESEARCH)),
|
||||
}
|
||||
limit = tool_args.get("limit", 5)
|
||||
try:
|
||||
limit = int(limit)
|
||||
except (TypeError, ValueError):
|
||||
limit = 5
|
||||
session = async_get_clientsession(hass)
|
||||
async with session.get(f"{base_url}/search", params=params, timeout=20) as resp:
|
||||
resp.raise_for_status()
|
||||
payload = await resp.json()
|
||||
results = []
|
||||
for item in payload.get("results", [])[:limit]:
|
||||
results.append(
|
||||
{
|
||||
"title": item.get("title"),
|
||||
"url": item.get("url"),
|
||||
"content": item.get("content"),
|
||||
}
|
||||
)
|
||||
return {"results": results}
|
||||
|
||||
|
||||
class GroqdConversationEntity(
|
||||
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
||||
):
|
||||
|
|
@ -187,6 +247,10 @@ class GroqdConversationEntity(
|
|||
response=intent_response, conversation_id=user_input.conversation_id
|
||||
)
|
||||
tools = [_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools]
|
||||
if options.get(CONF_SEARXNG_ENABLED, DEFAULT_SEARXNG_ENABLED):
|
||||
if tools is None:
|
||||
tools = []
|
||||
tools.append(_searxng_tool())
|
||||
|
||||
memory_scope = options.get(CONF_MEMORY_SCOPE, DEFAULT_MEMORY_SCOPE)
|
||||
memory_key = None
|
||||
|
|
@ -209,15 +273,23 @@ class GroqdConversationEntity(
|
|||
elif memory_scope == "global":
|
||||
memory_key = "global"
|
||||
|
||||
if user_input.conversation_id is not None:
|
||||
conversation_id = user_input.conversation_id
|
||||
history = self.history.get(conversation_id, [])
|
||||
elif memory_key and memory_key in self._memory_index:
|
||||
conversation_id = self._memory_index[memory_key]
|
||||
history = self.history.get(conversation_id, [])
|
||||
if memory_scope == "conversation":
|
||||
if user_input.conversation_id is not None:
|
||||
conversation_id = user_input.conversation_id
|
||||
history = self.history.get(conversation_id, [])
|
||||
elif memory_key and memory_key in self._memory_index:
|
||||
conversation_id = self._memory_index[memory_key]
|
||||
history = self.history.get(conversation_id, [])
|
||||
else:
|
||||
conversation_id = ulid.ulid_now()
|
||||
history = []
|
||||
else:
|
||||
conversation_id = ulid.ulid_now()
|
||||
history = []
|
||||
if memory_key and memory_key in self._memory_index:
|
||||
conversation_id = self._memory_index[memory_key]
|
||||
history = self.history.get(conversation_id, [])
|
||||
else:
|
||||
conversation_id = ulid.ulid_now()
|
||||
history = []
|
||||
|
||||
if (
|
||||
user_input.context
|
||||
|
|
@ -345,20 +417,34 @@ class GroqdConversationEntity(
|
|||
messages.append(message_convert(response))
|
||||
tool_calls = response.tool_calls
|
||||
|
||||
if not tool_calls or not llm_api:
|
||||
if not tool_calls:
|
||||
break
|
||||
|
||||
for tool_call in tool_calls:
|
||||
tool_input = llm.ToolInput(
|
||||
tool_name=tool_call.function.name,
|
||||
tool_args=json.loads(tool_call.function.arguments),
|
||||
)
|
||||
tool_name = tool_call.function.name
|
||||
try:
|
||||
tool_response = await llm_api.async_call_tool(tool_input)
|
||||
except (HomeAssistantError, vol.Invalid) as err:
|
||||
tool_response = {"error": type(err).__name__}
|
||||
if str(err):
|
||||
tool_response["error_text"] = str(err)
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
except json.JSONDecodeError:
|
||||
tool_args = {}
|
||||
|
||||
if tool_name == "web_search":
|
||||
try:
|
||||
tool_response = await _run_searxng(self.hass, options, tool_args)
|
||||
except Exception as err:
|
||||
tool_response = {"error": type(err).__name__, "error_text": str(err)}
|
||||
elif llm_api:
|
||||
tool_input = llm.ToolInput(
|
||||
tool_name=tool_name,
|
||||
tool_args=tool_args,
|
||||
)
|
||||
try:
|
||||
tool_response = await llm_api.async_call_tool(tool_input)
|
||||
except (HomeAssistantError, vol.Invalid) as err:
|
||||
tool_response = {"error": type(err).__name__}
|
||||
if str(err):
|
||||
tool_response["error_text"] = str(err)
|
||||
else:
|
||||
tool_response = {"error": "ToolNotAvailable", "error_text": tool_name}
|
||||
|
||||
messages.append(
|
||||
ChatCompletionToolMessageParam(
|
||||
|
|
|
|||
Loading…
Reference in a new issue