# Mirror of https://github.com/sudoxreboot/groqd
# (synced 2026-04-14 11:36:49 +00:00; 509 lines, 19 KiB, Python)
"""Conversation support for groqd."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Callable
|
|
import json
|
|
from typing import Any, Literal
|
|
|
|
from groq._types import NOT_GIVEN
|
|
from groq.types.chat import (
|
|
ChatCompletionAssistantMessageParam,
|
|
ChatCompletionMessage,
|
|
ChatCompletionMessageParam,
|
|
ChatCompletionMessageToolCallParam,
|
|
ChatCompletionSystemMessageParam,
|
|
ChatCompletionToolMessageParam,
|
|
ChatCompletionToolParam,
|
|
ChatCompletionUserMessageParam,
|
|
)
|
|
from groq.types.chat.chat_completion_message_tool_call_param import Function
|
|
from groq.types.shared_params import FunctionDefinition
|
|
import groq
|
|
import voluptuous as vol
|
|
from voluptuous_openapi import convert
|
|
|
|
from homeassistant.components import conversation
|
|
from homeassistant.components.conversation import trace
|
|
from homeassistant.config_entries import ConfigEntry
|
|
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
|
from homeassistant.core import HomeAssistant
|
|
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
|
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
|
from homeassistant.helpers.storage import Store
|
|
from homeassistant.helpers import device_registry as dr, intent, llm, template
|
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
|
from homeassistant.util import ulid
|
|
|
|
from . import GroqdConfigEntry
|
|
from .const import (
|
|
CONF_CHAT_MODEL,
|
|
CONF_CONTEXT_MESSAGES,
|
|
CONF_FREQUENCY_PENALTY,
|
|
CONF_MEMORY_SCOPE,
|
|
CONF_MAX_TOKENS,
|
|
CONF_PARALLEL_TOOL_CALLS,
|
|
CONF_PRESENCE_PENALTY,
|
|
CONF_PROMPT,
|
|
CONF_RESPONSE_FORMAT,
|
|
CONF_SEARXNG_ENABLED,
|
|
CONF_SEARXNG_LANGUAGE,
|
|
CONF_SEARXNG_SAFESEARCH,
|
|
CONF_SEARXNG_URL,
|
|
CONF_SEED,
|
|
CONF_STOP,
|
|
CONF_TEMPERATURE,
|
|
CONF_TOOL_CHOICE,
|
|
CONF_TOP_P,
|
|
DEFAULT_CHAT_MODEL,
|
|
DEFAULT_CONTEXT_MESSAGES,
|
|
DEFAULT_FREQUENCY_PENALTY,
|
|
DEFAULT_MAX_TOKENS,
|
|
DEFAULT_PARALLEL_TOOL_CALLS,
|
|
DEFAULT_PRESENCE_PENALTY,
|
|
DEFAULT_RESPONSE_FORMAT,
|
|
DEFAULT_SEARXNG_ENABLED,
|
|
DEFAULT_SEARXNG_LANGUAGE,
|
|
DEFAULT_SEARXNG_SAFESEARCH,
|
|
DEFAULT_SEARXNG_URL,
|
|
DEFAULT_TEMPERATURE,
|
|
DEFAULT_TOOL_CHOICE,
|
|
DEFAULT_TOP_P,
|
|
DEFAULT_MEMORY_SCOPE,
|
|
DOMAIN,
|
|
LOGGER,
|
|
)
|
|
|
|
# Upper bound on Groq round-trips per user turn in async_process: every
# response that contains tool calls runs those tools and asks the model again,
# so this caps runaway tool-call loops.
MAX_TOOL_ITERATIONS = 10
|
|
|
|
|
|
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: GroqdConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Register the single groqd conversation entity for this config entry."""
    async_add_entities([GroqdConversationEntity(config_entry)])
|
|
|
|
|
|
def _format_tool(
    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> ChatCompletionToolParam:
    """Convert a Home Assistant LLM tool into a Groq function-tool definition.

    The tool's voluptuous schema is turned into JSON Schema with
    ``voluptuous_openapi.convert``. A description is attached only when the
    tool has a non-empty one.
    """
    function_def = FunctionDefinition(
        name=tool.name,
        parameters=convert(tool.parameters, custom_serializer=custom_serializer),
    )
    description = tool.description
    if description:
        function_def["description"] = description
    return ChatCompletionToolParam(type="function", function=function_def)
|
|
|
|
|
|
def _parse_tool_choice(value: str) -> Any:
|
|
"""Parse tool choice option into Groq-compatible value."""
|
|
value = (value or "").strip()
|
|
if not value:
|
|
return NOT_GIVEN
|
|
if value in {"auto", "none", "required"}:
|
|
return value
|
|
if value.startswith("tool:"):
|
|
name = value.split(":", 1)[1].strip()
|
|
if name:
|
|
return {"type": "function", "function": {"name": name}}
|
|
return value
|
|
|
|
|
|
def _parse_stop_sequences(value: str | None) -> list[str] | None:
|
|
if not value:
|
|
return None
|
|
parts = [part.strip() for part in value.replace("\r", "").split("\n")]
|
|
expanded = []
|
|
for part in parts:
|
|
expanded.extend([item.strip() for item in part.split(",")])
|
|
return [part for part in expanded if part]
|
|
|
|
|
|
def _searxng_tool() -> ChatCompletionToolParam:
    """Build the ``web_search`` function-tool definition offered to the model.

    Only ``query`` is required. ``limit`` (integer, default 5), ``language``
    and ``safesearch`` are optional.
    """
    properties = {
        "query": {"type": "string"},
        "limit": {"type": "integer", "default": 5},
        "language": {"type": "string"},
        "safesearch": {"type": "integer"},
    }
    schema = {"type": "object", "properties": properties, "required": ["query"]}
    function_def = FunctionDefinition(
        name="web_search",
        parameters=schema,
        description="Search the web via searxng and return top results.",
    )
    return ChatCompletionToolParam(type="function", function=function_def)
|
|
|
|
|
|
async def _run_searxng(
    hass: HomeAssistant,
    options: dict[str, Any],
    tool_args: dict[str, Any],
) -> dict[str, Any]:
    """Run a searxng search for a ``web_search`` tool call.

    Args:
        hass: Home Assistant instance. Its shared aiohttp session is used.
        options: Config-entry options holding the searxng URL and the default
            ``language`` / ``safesearch`` values.
        tool_args: Arguments decoded from the model's tool call. ``query`` is
            the search text; ``limit``, ``language`` and ``safesearch``
            optionally override the defaults.

    Returns:
        ``{"results": [...]}``, where each entry holds the ``title``, ``url``
        and ``content`` of one searxng result. At most ``limit`` entries are
        returned; the limit defaults to 5 when missing, invalid or not
        positive.

    HTTP errors (non-2xx responses, via ``raise_for_status``) and transport
    errors from aiohttp propagate to the caller.
    """

    def _as_int(value: Any) -> int | None:
        """Return ``value`` as an int, or ``None`` if it is missing or not numeric."""
        if value is None:
            return None
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return None

    # Arguments are produced by the model and may be null or not a dict.
    if not isinstance(tool_args, dict):
        tool_args = {}

    base_url = (options.get(CONF_SEARXNG_URL) or DEFAULT_SEARXNG_URL).rstrip("/")

    query = tool_args.get("query")
    params: dict[str, Any] = {
        "q": "" if query is None else str(query),
        "format": "json",
    }

    # aiohttp rejects None (and bool) query values, so only send keys that
    # actually have a usable value.
    language = tool_args.get("language") or options.get(
        CONF_SEARXNG_LANGUAGE, DEFAULT_SEARXNG_LANGUAGE
    )
    if language:
        params["language"] = str(language)

    # An explicit 0 from the model is a valid override, hence the None check
    # instead of `or`.
    safesearch = _as_int(tool_args.get("safesearch"))
    if safesearch is None:
        safesearch = _as_int(
            options.get(CONF_SEARXNG_SAFESEARCH, DEFAULT_SEARXNG_SAFESEARCH)
        )
    if safesearch is not None:
        params["safesearch"] = safesearch

    # A non-positive limit would turn the slice below into "drop from the end".
    limit = _as_int(tool_args.get("limit"))
    if limit is None or limit < 1:
        limit = 5

    session = async_get_clientsession(hass)
    async with session.get(f"{base_url}/search", params=params, timeout=20) as resp:
        resp.raise_for_status()
        payload = await resp.json()

    raw_results = payload.get("results") if isinstance(payload, dict) else None
    if not isinstance(raw_results, list):
        raw_results = []
    results = [
        {
            "title": item.get("title"),
            "url": item.get("url"),
            "content": item.get("content"),
        }
        for item in raw_results
        if isinstance(item, dict)
    ][:limit]
    return {"results": results}
|
|
|
|
|
|
class GroqdConversationEntity(
    conversation.ConversationEntity, conversation.AbstractConversationAgent
):
    """groqd conversation agent.

    Each user turn is sent to Groq chat completions together with the
    rendered system prompt and prior history. Tool calls are executed against
    the selected Home Assistant LLM API and/or the optional searxng
    ``web_search`` tool, for at most ``MAX_TOOL_ITERATIONS`` round-trips.

    Memory scopes (option ``CONF_MEMORY_SCOPE``):
      * ``"conversation"``: history is kept in memory, per conversation id.
      * ``"device"`` / ``"user"`` / ``"global"``: history is kept per memory
        key and saved with a ``Store``.
    """

    _attr_has_entity_name = True
    _attr_name = None

    def __init__(self, entry: GroqdConfigEntry) -> None:
        """Initialize the agent for ``entry``."""
        self.entry = entry
        # Conversation-scoped history, keyed by conversation id (memory only).
        self.history: dict[str, list[ChatCompletionMessageParam]] = {}
        # Memory key ("device:<id>", "user:<id>" or "global") -> conversation id.
        self._memory_index: dict[str, str] = {}
        # Memory key -> trimmed message history; written to ``self._store``.
        self._persisted_history: dict[str, list[ChatCompletionMessageParam]] = {}
        self._store: Store | None = None
        self._attr_unique_id = entry.entry_id
        self._attr_device_info = dr.DeviceInfo(
            identifiers={(DOMAIN, entry.entry_id)},
            name=entry.title,
            manufacturer="Groq",
            model="Groq Cloud",
            entry_type=dr.DeviceEntryType.SERVICE,
        )
        # Use the same "is an LLM API selected" test as async_process, so the
        # value "none" does not advertise device control.
        if self._llm_api_id(self.entry.options):
            self._attr_supported_features = (
                conversation.ConversationEntityFeature.CONTROL
            )

    @property
    def supported_languages(self) -> list[str] | Literal["*"]:
        """Return the supported languages (all of them)."""
        return MATCH_ALL

    async def async_added_to_hass(self) -> None:
        """Register as agent, load persisted memory and watch option updates."""
        await super().async_added_to_hass()
        conversation.async_set_agent(self.hass, self.entry, self)
        self._store = Store(self.hass, 1, f"groqd_history_{self.entry.entry_id}")
        data = await self._store.async_load() or {}
        self._persisted_history = data.get("history", {})
        self._memory_index = data.get("memory_index", {})
        self.entry.async_on_unload(
            self.entry.add_update_listener(self._async_entry_update_listener)
        )

    async def async_will_remove_from_hass(self) -> None:
        """Unregister the agent when the entity is removed."""
        conversation.async_unset_agent(self.hass, self.entry)
        await super().async_will_remove_from_hass()

    async def async_process(
        self, user_input: conversation.ConversationInput
    ) -> conversation.ConversationResult:
        """Process one user sentence and return the agent's reply."""
        options = self.entry.options
        intent_response = intent.IntentResponse(language=user_input.language)
        llm_api: llm.APIInstance | None = None
        tools: list[ChatCompletionToolParam] | None = None
        user_name: str | None = None
        llm_context = llm.LLMContext(
            platform=DOMAIN,
            context=user_input.context,
            language=user_input.language,
            assistant=conversation.DOMAIN,
            device_id=user_input.device_id,
        )

        if api_id := self._llm_api_id(options):
            try:
                llm_api = await llm.async_get_api(self.hass, api_id, llm_context)
            except HomeAssistantError as err:
                LOGGER.error("Error getting LLM API: %s", err)
                intent_response.async_set_error(
                    intent.IntentResponseErrorCode.UNKNOWN,
                    f"Error preparing LLM API: {err}",
                )
                return conversation.ConversationResult(
                    response=intent_response,
                    conversation_id=user_input.conversation_id,
                )
            tools = [
                _format_tool(tool, llm_api.custom_serializer)
                for tool in llm_api.tools
            ]

        if options.get(CONF_SEARXNG_ENABLED, DEFAULT_SEARXNG_ENABLED):
            tools = [*(tools or []), _searxng_tool()]

        memory_scope = options.get(CONF_MEMORY_SCOPE, DEFAULT_MEMORY_SCOPE)
        memory_key = self._resolve_memory_key(memory_scope, user_input)
        LOGGER.debug(
            "Memory scope=%s memory_key=%s conv_id=%s device_id=%s user_id=%s",
            memory_scope,
            memory_key,
            user_input.conversation_id,
            user_input.device_id,
            user_input.context.user_id if user_input.context else None,
        )

        conversation_id, history = self._load_history(
            memory_scope, memory_key, user_input.conversation_id
        )
        LOGGER.debug(
            "Conversation id=%s history_len=%s memory_scope=%s",
            conversation_id,
            len(history),
            memory_scope,
        )

        if (
            user_input.context
            and user_input.context.user_id
            and (user := await self.hass.auth.async_get_user(user_input.context.user_id))
        ):
            user_name = user.name

        try:
            prompt_parts = [
                template.Template(
                    llm.DATE_TIME_PROMPT
                    + options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
                    self.hass,
                ).async_render(
                    {
                        "ha_name": self.hass.config.location_name,
                        "user_name": user_name,
                        "llm_context": llm_context,
                    },
                    parse_result=False,
                )
            ]
        except TemplateError as err:
            LOGGER.error("Error rendering prompt: %s", err)
            intent_response.async_set_error(
                intent.IntentResponseErrorCode.UNKNOWN,
                f"Template error: {err}",
            )
            return conversation.ConversationResult(
                response=intent_response, conversation_id=conversation_id
            )

        if llm_api:
            prompt_parts.append(llm_api.api_prompt)

        prompt = "\n".join(prompt_parts)

        messages: list[ChatCompletionMessageParam] = [
            ChatCompletionSystemMessageParam(role="system", content=prompt),
            *history,
            ChatCompletionUserMessageParam(role="user", content=user_input.text),
        ]

        trace.async_conversation_trace_append(
            trace.ConversationTraceEventType.AGENT_DETAIL,
            {"messages": messages, "tools": llm_api.tools if llm_api else None},
        )

        client = self.entry.runtime_data.client
        # Everything except `messages` is identical for every round-trip.
        request_kwargs = self._build_request_kwargs(options, tools, conversation_id)

        response: ChatCompletionMessage | None = None
        for _ in range(MAX_TOOL_ITERATIONS):
            try:
                result = await client.chat.completions.create(
                    messages=messages, **request_kwargs
                )
            except groq.GroqError as err:
                LOGGER.error("Error talking to Groq: %s", err)
                intent_response.async_set_error(
                    intent.IntentResponseErrorCode.UNKNOWN,
                    f"Groq error: {err}",
                )
                return conversation.ConversationResult(
                    response=intent_response, conversation_id=conversation_id
                )

            response = result.choices[0].message
            messages.append(self._message_convert(response))

            if not response.tool_calls:
                break

            for tool_call in response.tool_calls:
                tool_response = await self._async_call_tool(
                    tool_call, options, llm_api
                )
                messages.append(
                    ChatCompletionToolMessageParam(
                        role="tool",
                        tool_call_id=tool_call.id,
                        content=json.dumps(tool_response),
                    )
                )
        else:
            # Loop ran out without a response free of tool calls.
            LOGGER.warning(
                "Stopped after %s tool iterations without a final answer",
                MAX_TOOL_ITERATIONS,
            )

        # messages[0] is the system prompt, which is rebuilt on every turn.
        history = self._trim_history(
            messages[1:],
            options.get(CONF_CONTEXT_MESSAGES, DEFAULT_CONTEXT_MESSAGES),
        )
        await self._async_save_history(
            memory_scope, memory_key, conversation_id, history
        )

        speech = response.content if response is not None else None
        intent_response.async_set_speech(speech or "")
        return conversation.ConversationResult(
            response=intent_response, conversation_id=conversation_id
        )

    @staticmethod
    def _llm_api_id(options: Any) -> Any:
        """Return the configured LLM API option, or None when unset or "none"."""
        api_id = options.get(CONF_LLM_HASS_API)
        if not api_id or api_id == "none":
            return None
        return api_id

    @staticmethod
    def _resolve_memory_key(
        memory_scope: str, user_input: conversation.ConversationInput
    ) -> str | None:
        """Return the memory key for a keyed scope, or None otherwise.

        ``"device"`` prefers the device id, then falls back to the user id and
        then to ``"global"``. ``"user"`` prefers the user id, then falls back
        to ``"global"``. ``"global"`` always returns ``"global"``. Any other
        scope (including ``"conversation"``) returns None.
        """
        user_id = user_input.context.user_id if user_input.context else None
        if memory_scope == "device" and user_input.device_id:
            return f"device:{user_input.device_id}"
        if memory_scope in ("device", "user") and user_id:
            return f"user:{user_id}"
        if memory_scope in ("device", "user", "global"):
            return "global"
        return None

    def _load_history(
        self,
        memory_scope: str,
        memory_key: str | None,
        requested_id: str | None,
    ) -> tuple[str, list[ChatCompletionMessageParam]]:
        """Return ``(conversation_id, prior_history)`` for this turn.

        In ``"conversation"`` scope the caller's conversation id is reused
        when given. In keyed scopes the id stored for the memory key is
        reused. Otherwise a fresh ULID is generated with empty history.
        """
        if memory_scope == "conversation":
            if requested_id is None:
                return ulid.ulid_now(), []
            return requested_id, self.history.get(requested_id, [])
        if memory_key and memory_key in self._memory_index:
            return (
                self._memory_index[memory_key],
                self._persisted_history.get(memory_key, []),
            )
        return ulid.ulid_now(), []

    @staticmethod
    def _build_request_kwargs(
        options: Any,
        tools: list[ChatCompletionToolParam] | None,
        conversation_id: str,
    ) -> dict[str, Any]:
        """Return the keyword arguments for ``chat.completions.create`` (minus messages)."""
        max_tokens = options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
        try:
            max_tokens = int(max_tokens)
        except (TypeError, ValueError):
            max_tokens = DEFAULT_MAX_TOKENS
        if max_tokens <= 0:
            max_tokens = DEFAULT_MAX_TOKENS

        response_format = options.get(CONF_RESPONSE_FORMAT, DEFAULT_RESPONSE_FORMAT)
        stop_sequences = _parse_stop_sequences(options.get(CONF_STOP))

        kwargs: dict[str, Any] = {
            "model": options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL),
            "max_tokens": max_tokens,
            "top_p": options.get(CONF_TOP_P, DEFAULT_TOP_P),
            "temperature": options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE),
            "frequency_penalty": options.get(
                CONF_FREQUENCY_PENALTY, DEFAULT_FREQUENCY_PENALTY
            ),
            "presence_penalty": options.get(
                CONF_PRESENCE_PENALTY, DEFAULT_PRESENCE_PENALTY
            ),
            "seed": options.get(CONF_SEED, NOT_GIVEN),
            "stop": stop_sequences or NOT_GIVEN,
            "response_format": (
                {"type": "json_object"}
                if response_format == "json_object"
                else NOT_GIVEN
            ),
            "user": conversation_id,
        }
        # tool_choice / parallel_tool_calls only make sense alongside tools;
        # send them only when tools are actually offered.
        if tools:
            kwargs["tools"] = tools
            kwargs["tool_choice"] = _parse_tool_choice(
                options.get(CONF_TOOL_CHOICE, DEFAULT_TOOL_CHOICE)
            )
            kwargs["parallel_tool_calls"] = options.get(
                CONF_PARALLEL_TOOL_CALLS, DEFAULT_PARALLEL_TOOL_CALLS
            )
        return kwargs

    @staticmethod
    def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
        """Convert a response message into an assistant message param for the next request."""
        param = ChatCompletionAssistantMessageParam(
            role=message.role,
            content=message.content,
        )
        if message.tool_calls:
            param["tool_calls"] = [
                ChatCompletionMessageToolCallParam(
                    id=tool_call.id,
                    function=Function(
                        arguments=tool_call.function.arguments,
                        name=tool_call.function.name,
                    ),
                    type=tool_call.type,
                )
                for tool_call in message.tool_calls
            ]
        return param

    async def _async_call_tool(
        self,
        tool_call: Any,
        options: Any,
        llm_api: llm.APIInstance | None,
    ) -> Any:
        """Execute one model-requested tool call and return the result for the model.

        Search failures, ``HomeAssistantError`` / ``vol.Invalid`` from HA tools
        and unknown tools are reported back to the model as an ``error`` dict
        instead of aborting the turn.
        """
        tool_name = tool_call.function.name
        try:
            tool_args = json.loads(tool_call.function.arguments)
        except (TypeError, ValueError):
            # Missing (None) or malformed JSON arguments from the model.
            tool_args = {}
        if not isinstance(tool_args, dict):
            tool_args = {}

        # Only route to searxng when it is enabled, so a disabled web search
        # never shadows an HA tool that happens to share the name.
        if tool_name == "web_search" and options.get(
            CONF_SEARXNG_ENABLED, DEFAULT_SEARXNG_ENABLED
        ):
            try:
                return await _run_searxng(self.hass, options, tool_args)
            except Exception as err:  # noqa: BLE001 - any search failure is reported to the model
                LOGGER.warning("web_search failed: %s", err)
                return {"error": type(err).__name__, "error_text": str(err)}

        if llm_api:
            tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
            try:
                return await llm_api.async_call_tool(tool_input)
            except (HomeAssistantError, vol.Invalid) as err:
                error: dict[str, Any] = {"error": type(err).__name__}
                if str(err):
                    error["error_text"] = str(err)
                return error

        return {"error": "ToolNotAvailable", "error_text": tool_name}

    @staticmethod
    def _trim_history(
        history: list[ChatCompletionMessageParam], limit_option: Any
    ) -> list[ChatCompletionMessageParam]:
        """Apply the context-message limit to ``history``.

        A limit of 0 keeps nothing, a negative limit keeps everything and a
        positive limit keeps the most recent N messages. Non-integer values
        fall back to ``DEFAULT_CONTEXT_MESSAGES``.
        """
        try:
            limit = int(limit_option)
        except (TypeError, ValueError):
            limit = DEFAULT_CONTEXT_MESSAGES
        if limit == 0:
            return []
        if limit > 0:
            history = history[-limit:]
        # Cutting from the front can leave "tool" results whose assistant
        # tool_calls message was dropped. OpenAI-compatible chat APIs reject
        # such orphaned tool messages on the next request, so skip them.
        start = 0
        while start < len(history) and history[start].get("role") == "tool":
            start += 1
        return history[start:]

    async def _async_save_history(
        self,
        memory_scope: str,
        memory_key: str | None,
        conversation_id: str,
        history: list[ChatCompletionMessageParam],
    ) -> None:
        """Remember ``history`` in memory or, for keyed scopes, in the store."""
        if memory_scope == "conversation":
            self.history[conversation_id] = history
            return
        if not memory_key:
            return
        self._persisted_history[memory_key] = history
        self._memory_index[memory_key] = conversation_id
        if self._store:
            await self._store.async_save(
                {"history": self._persisted_history, "memory_index": self._memory_index}
            )
        LOGGER.debug(
            "Persisted memory_key=%s history_len=%s conv_id=%s",
            memory_key,
            len(history),
            conversation_id,
        )

    async def _async_entry_update_listener(
        self, hass: HomeAssistant, entry: ConfigEntry
    ) -> None:
        """Reload the config entry when its options change."""
        await hass.config_entries.async_reload(entry.entry_id)
|