diff --git a/custom_components/groqd/conversation.py b/custom_components/groqd/conversation.py
index 9b311c3..2b8d270 100644
--- a/custom_components/groqd/conversation.py
+++ b/custom_components/groqd/conversation.py
@@ -29,6 +29,7 @@ from homeassistant.config_entries import ConfigEntry
 from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
 from homeassistant.core import HomeAssistant
 from homeassistant.exceptions import HomeAssistantError, TemplateError
+from homeassistant.helpers.aiohttp_client import async_get_clientsession
 from homeassistant.helpers import device_registry as dr, intent, llm, template
 from homeassistant.helpers.entity_platform import AddEntitiesCallback
 from homeassistant.util import ulid
@@ -44,6 +45,10 @@ from .const import (
     CONF_PRESENCE_PENALTY,
     CONF_PROMPT,
     CONF_RESPONSE_FORMAT,
+    CONF_SEARXNG_ENABLED,
+    CONF_SEARXNG_LANGUAGE,
+    CONF_SEARXNG_SAFESEARCH,
+    CONF_SEARXNG_URL,
     CONF_SEED,
     CONF_STOP,
     CONF_TEMPERATURE,
@@ -56,6 +61,10 @@ from .const import (
     DEFAULT_PARALLEL_TOOL_CALLS,
     DEFAULT_PRESENCE_PENALTY,
     DEFAULT_RESPONSE_FORMAT,
+    DEFAULT_SEARXNG_ENABLED,
+    DEFAULT_SEARXNG_LANGUAGE,
+    DEFAULT_SEARXNG_SAFESEARCH,
+    DEFAULT_SEARXNG_URL,
     DEFAULT_TEMPERATURE,
     DEFAULT_TOOL_CHOICE,
     DEFAULT_TOP_P,
@@ -114,6 +123,57 @@ def _parse_stop_sequences(value: str | None) -> list[str] | None:
     return [part for part in expanded if part]
 
 
+def _searxng_tool() -> ChatCompletionToolParam:
+    tool_spec = FunctionDefinition(
+        name="web_search",
+        parameters={
+            "type": "object",
+            "properties": {
+                "query": {"type": "string"},
+                "limit": {"type": "integer", "default": 5},
+                "language": {"type": "string"},
+                "safesearch": {"type": "integer"},
+            },
+            "required": ["query"],
+        },
+        description="Search the web via searxng and return top results.",
+    )
+    return ChatCompletionToolParam(type="function", function=tool_spec)
+
+
+async def _run_searxng(
+    hass: HomeAssistant,
+    options: dict[str, Any],
+    tool_args: dict[str, Any],
+) -> dict[str, Any]:
+    base_url = options.get(CONF_SEARXNG_URL, DEFAULT_SEARXNG_URL).rstrip("/")
+    params = {
+        "q": tool_args.get("query", ""),
+        "format": "json",
+        "language": tool_args.get("language") or options.get(CONF_SEARXNG_LANGUAGE, DEFAULT_SEARXNG_LANGUAGE),
+        "safesearch": tool_args.get("safesearch", options.get(CONF_SEARXNG_SAFESEARCH, DEFAULT_SEARXNG_SAFESEARCH)),
+    }
+    limit = tool_args.get("limit", 5)
+    try:
+        limit = int(limit)
+    except (TypeError, ValueError):
+        limit = 5
+    session = async_get_clientsession(hass)
+    async with session.get(f"{base_url}/search", params=params, timeout=20) as resp:
+        resp.raise_for_status()
+        payload = await resp.json()
+    results = []
+    for item in payload.get("results", [])[:limit]:
+        results.append(
+            {
+                "title": item.get("title"),
+                "url": item.get("url"),
+                "content": item.get("content"),
+            }
+        )
+    return {"results": results}
+
+
 class GroqdConversationEntity(
     conversation.ConversationEntity, conversation.AbstractConversationAgent
 ):
@@ -187,6 +247,10 @@
                     response=intent_response, conversation_id=user_input.conversation_id
                 )
             tools = [_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools]
+        if options.get(CONF_SEARXNG_ENABLED, DEFAULT_SEARXNG_ENABLED):
+            if tools is None:
+                tools = []
+            tools.append(_searxng_tool())
 
         memory_scope = options.get(CONF_MEMORY_SCOPE, DEFAULT_MEMORY_SCOPE)
         memory_key = None
@@ -209,15 +273,23 @@
         elif memory_scope == "global":
             memory_key = "global"
 
-        if user_input.conversation_id is not None:
-            conversation_id = user_input.conversation_id
-            history = self.history.get(conversation_id, [])
-        elif memory_key and memory_key in self._memory_index:
-            conversation_id = self._memory_index[memory_key]
-            history = self.history.get(conversation_id, [])
+        if memory_scope == "conversation":
+            if user_input.conversation_id is not None:
+                conversation_id = user_input.conversation_id
+                history = self.history.get(conversation_id, [])
+            elif memory_key and memory_key in self._memory_index:
+                conversation_id = self._memory_index[memory_key]
+                history = self.history.get(conversation_id, [])
+            else:
+                conversation_id = ulid.ulid_now()
+                history = []
         else:
-            conversation_id = ulid.ulid_now()
-            history = []
+            if memory_key and memory_key in self._memory_index:
+                conversation_id = self._memory_index[memory_key]
+                history = self.history.get(conversation_id, [])
+            else:
+                conversation_id = ulid.ulid_now()
+                history = []
 
         if (
             user_input.context
@@ -345,20 +417,34 @@
             messages.append(message_convert(response))
 
             tool_calls = response.tool_calls
-            if not tool_calls or not llm_api:
+            if not tool_calls:
                 break
 
             for tool_call in tool_calls:
-                tool_input = llm.ToolInput(
-                    tool_name=tool_call.function.name,
-                    tool_args=json.loads(tool_call.function.arguments),
-                )
+                tool_name = tool_call.function.name
                 try:
-                    tool_response = await llm_api.async_call_tool(tool_input)
-                except (HomeAssistantError, vol.Invalid) as err:
-                    tool_response = {"error": type(err).__name__}
-                    if str(err):
-                        tool_response["error_text"] = str(err)
+                    tool_args = json.loads(tool_call.function.arguments)
+                except json.JSONDecodeError:
+                    tool_args = {}
+
+                if tool_name == "web_search":
+                    try:
+                        tool_response = await _run_searxng(self.hass, options, tool_args)
+                    except Exception as err:
+                        tool_response = {"error": type(err).__name__, "error_text": str(err)}
+                elif llm_api:
+                    tool_input = llm.ToolInput(
+                        tool_name=tool_name,
+                        tool_args=tool_args,
+                    )
+                    try:
+                        tool_response = await llm_api.async_call_tool(tool_input)
+                    except (HomeAssistantError, vol.Invalid) as err:
+                        tool_response = {"error": type(err).__name__}
+                        if str(err):
+                            tool_response["error_text"] = str(err)
+                else:
+                    tool_response = {"error": "ToolNotAvailable", "error_text": tool_name}
 
                 messages.append(
                     ChatCompletionToolMessageParam(