groqd/custom_components/groqd/conversation.py
2025-12-20 11:39:38 -06:00

368 lines
13 KiB
Python

"""Conversation support for groqd."""
from __future__ import annotations
from collections.abc import Callable
import json
from typing import Any, Literal
from groq._types import NOT_GIVEN
from groq.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionMessage,
ChatCompletionMessageParam,
ChatCompletionMessageToolCallParam,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionToolParam,
ChatCompletionUserMessageParam,
)
from groq.types.chat.chat_completion_message_tool_call_param import Function
from groq.types.shared_params import FunctionDefinition
import groq
import voluptuous as vol
from voluptuous_openapi import convert
from homeassistant.components import conversation
from homeassistant.components.conversation import trace
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import device_registry as dr, intent, llm, template
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import ulid
from . import GroqdConfigEntry
from .const import (
CONF_CHAT_MODEL,
CONF_CONTEXT_MESSAGES,
CONF_FREQUENCY_PENALTY,
CONF_MAX_TOKENS,
CONF_PARALLEL_TOOL_CALLS,
CONF_PRESENCE_PENALTY,
CONF_PROMPT,
CONF_RESPONSE_FORMAT,
CONF_SEED,
CONF_STOP,
CONF_TEMPERATURE,
CONF_TOOL_CHOICE,
CONF_TOP_P,
DEFAULT_CHAT_MODEL,
DEFAULT_CONTEXT_MESSAGES,
DEFAULT_FREQUENCY_PENALTY,
DEFAULT_MAX_TOKENS,
DEFAULT_PARALLEL_TOOL_CALLS,
DEFAULT_PRESENCE_PENALTY,
DEFAULT_RESPONSE_FORMAT,
DEFAULT_TEMPERATURE,
DEFAULT_TOOL_CHOICE,
DEFAULT_TOP_P,
DOMAIN,
LOGGER,
)
# Upper bound on model round-trips per user turn in async_process; each round
# may execute the tool calls requested by the model before asking it again.
MAX_TOOL_ITERATIONS: int = 10
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: GroqdConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Set up the groqd conversation entity for a config entry.

    Exactly one conversation agent entity is created per entry.
    """
    async_add_entities([GroqdConversationEntity(config_entry)])
def _format_tool(
    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> ChatCompletionToolParam:
    """Build a Groq function-tool definition from a Home Assistant LLM tool.

    The tool's voluptuous parameter schema is converted with ``convert``
    (using ``custom_serializer`` for non-standard validators). A description
    key is only added when the tool has a non-empty description.
    """
    schema = convert(tool.parameters, custom_serializer=custom_serializer)
    definition = FunctionDefinition(name=tool.name, parameters=schema)
    description = tool.description
    if description:
        definition["description"] = description
    return ChatCompletionToolParam(type="function", function=definition)
def _parse_tool_choice(value: str) -> Any:
"""Parse tool choice option into Groq-compatible value."""
value = (value or "").strip()
if not value:
return NOT_GIVEN
if value in {"auto", "none", "required"}:
return value
if value.startswith("tool:"):
name = value.split(":", 1)[1].strip()
if name:
return {"type": "function", "function": {"name": name}}
return value
def _parse_stop_sequences(value: str | None) -> list[str] | None:
if not value:
return None
parts = [part.strip() for part in value.replace("\r", "").split("\n")]
expanded = []
for part in parts:
expanded.extend([item.strip() for item in part.split(",")])
return [part for part in expanded if part]
class GroqdConversationEntity(
    conversation.ConversationEntity, conversation.AbstractConversationAgent
):
    """groqd conversation agent.

    Each user turn is sent to Groq's chat-completions endpoint. When a Home
    Assistant LLM API is configured, its tools are offered to the model and
    any tool calls the model makes are executed and fed back, for at most
    ``MAX_TOOL_ITERATIONS`` rounds. Conversation history is kept in memory,
    keyed by conversation id.
    """

    _attr_has_entity_name = True
    _attr_name = None

    def __init__(self, entry: GroqdConfigEntry) -> None:
        """Initialize the agent for ``entry``."""
        self.entry = entry
        # conversation_id -> prior messages, excluding the system prompt (the
        # prompt is re-rendered on every turn).
        self.history: dict[str, list[ChatCompletionMessageParam]] = {}
        self._attr_unique_id = entry.entry_id
        self._attr_device_info = dr.DeviceInfo(
            identifiers={(DOMAIN, entry.entry_id)},
            name=entry.title,
            manufacturer="Groq",
            model="Groq Cloud",
            entry_type=dr.DeviceEntryType.SERVICE,
        )
        # Use the same "is an LLM API selected" rule as async_process, so the
        # literal "none" option does not advertise device control.
        if self._llm_api_id(self.entry.options) is not None:
            self._attr_supported_features = (
                conversation.ConversationEntityFeature.CONTROL
            )

    @property
    def supported_languages(self) -> list[str] | Literal["*"]:
        """Return the supported languages (all of them)."""
        return MATCH_ALL

    async def async_added_to_hass(self) -> None:
        """Register the agent and reload the entry whenever it is updated."""
        await super().async_added_to_hass()
        conversation.async_set_agent(self.hass, self.entry, self)
        self.entry.async_on_unload(
            self.entry.add_update_listener(self._async_entry_update_listener)
        )

    async def async_will_remove_from_hass(self) -> None:
        """Unregister the agent before the entity is removed."""
        conversation.async_unset_agent(self.hass, self.entry)
        await super().async_will_remove_from_hass()

    async def async_process(
        self, user_input: conversation.ConversationInput
    ) -> conversation.ConversationResult:
        """Process one user utterance and return the agent's reply.

        Failures while preparing the LLM API, rendering the prompt template or
        calling Groq are returned as an error intent response, not raised.
        """
        options = self.entry.options
        intent_response = intent.IntentResponse(language=user_input.language)
        llm_api: llm.APIInstance | None = None
        tools: list[ChatCompletionToolParam] | None = None
        user_name: str | None = None
        llm_context = llm.LLMContext(
            platform=DOMAIN,
            context=user_input.context,
            language=user_input.language,
            assistant=conversation.DOMAIN,
            device_id=user_input.device_id,
        )

        if (api_id := self._llm_api_id(options)) is not None:
            try:
                llm_api = await llm.async_get_api(self.hass, api_id, llm_context)
            except HomeAssistantError as err:
                LOGGER.error("Error getting LLM API: %s", err)
                intent_response.async_set_error(
                    intent.IntentResponseErrorCode.UNKNOWN,
                    f"Error preparing LLM API: {err}",
                )
                return conversation.ConversationResult(
                    response=intent_response,
                    conversation_id=user_input.conversation_id,
                )
            tools = [
                _format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
            ]

        conversation_id, history = self._resolve_conversation(
            user_input.conversation_id
        )

        if (
            user_input.context
            and user_input.context.user_id
            and (user := await self.hass.auth.async_get_user(user_input.context.user_id))
        ):
            user_name = user.name

        try:
            prompt_parts = [
                template.Template(
                    llm.DATE_TIME_PROMPT
                    + options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
                    self.hass,
                ).async_render(
                    {
                        "ha_name": self.hass.config.location_name,
                        "user_name": user_name,
                        "llm_context": llm_context,
                    },
                    parse_result=False,
                )
            ]
        except TemplateError as err:
            LOGGER.error("Error rendering prompt: %s", err)
            intent_response.async_set_error(
                intent.IntentResponseErrorCode.UNKNOWN,
                f"Template error: {err}",
            )
            return conversation.ConversationResult(
                response=intent_response, conversation_id=conversation_id
            )

        if llm_api:
            prompt_parts.append(llm_api.api_prompt)
        prompt = "\n".join(prompt_parts)

        messages: list[ChatCompletionMessageParam] = [
            ChatCompletionSystemMessageParam(role="system", content=prompt),
            *history,
            ChatCompletionUserMessageParam(role="user", content=user_input.text),
        ]

        trace.async_conversation_trace_append(
            trace.ConversationTraceEventType.AGENT_DETAIL,
            {"messages": messages, "tools": llm_api.tools if llm_api else None},
        )

        client = self.entry.runtime_data.client
        # Options are read once; every round of the tool loop reuses them.
        request_kwargs = self._build_request_kwargs(options, tools, conversation_id)

        for _iteration in range(MAX_TOOL_ITERATIONS):
            try:
                result = await client.chat.completions.create(
                    messages=messages, **request_kwargs
                )
            except groq.GroqError as err:
                LOGGER.error("Error talking to Groq: %s", err)
                intent_response.async_set_error(
                    intent.IntentResponseErrorCode.UNKNOWN,
                    f"Groq error: {err}",
                )
                return conversation.ConversationResult(
                    response=intent_response, conversation_id=conversation_id
                )

            response = result.choices[0].message
            messages.append(self._convert_message(response))

            if not response.tool_calls or not llm_api:
                break

            for tool_call in response.tool_calls:
                messages.append(await self._async_run_tool_call(llm_api, tool_call))
        else:
            LOGGER.warning(
                "Reached the limit of %d tool-call rounds; replying with the last "
                "model message",
                MAX_TOOL_ITERATIONS,
            )

        context_limit = self._coerce_int(
            options.get(CONF_CONTEXT_MESSAGES, DEFAULT_CONTEXT_MESSAGES),
            DEFAULT_CONTEXT_MESSAGES,
        )
        # messages[0] is the system prompt, which is rebuilt on every turn.
        self.history[conversation_id] = self._trim_history(messages[1:], context_limit)

        intent_response.async_set_speech(response.content or "")
        return conversation.ConversationResult(
            response=intent_response, conversation_id=conversation_id
        )

    async def _async_entry_update_listener(
        self, hass: HomeAssistant, entry: ConfigEntry
    ) -> None:
        """Reload the config entry so updated options take effect."""
        await hass.config_entries.async_reload(entry.entry_id)

    @staticmethod
    def _llm_api_id(options: Any) -> Any:
        """Return the configured LLM API id(s), or None when no API is selected.

        A missing/empty option and the literal string ``"none"`` both mean
        "no API".
        """
        api_id = options.get(CONF_LLM_HASS_API)
        if not api_id or api_id == "none":
            return None
        return api_id

    @staticmethod
    def _coerce_int(value: Any, default: int) -> int:
        """Return ``value`` converted with ``int()``, or ``default`` on failure.

        Option values are not type-checked here; this tolerates floats and
        numeric strings (NOTE(review): presumably number selectors can store
        floats — confirm against the config flow).
        """
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return default

    def _resolve_conversation(
        self, requested_id: str | None
    ) -> tuple[str, list[ChatCompletionMessageParam]]:
        """Return the conversation id to use and the history stored for it.

        * no id -> a fresh ULID and empty history
        * an id with stored history -> that id and its history
        * an unknown id that parses as a ULID -> a fresh ULID, empty history
        * any other unknown id -> kept as-is, empty history
        """
        if requested_id is None:
            return ulid.ulid_now(), []
        if requested_id in self.history:
            return requested_id, self.history[requested_id]
        try:
            ulid.ulid_to_bytes(requested_id)
        except ValueError:
            # Not a ULID: the caller picked its own id, so keep tracking it.
            return requested_id, []
        return ulid.ulid_now(), []

    @classmethod
    def _build_request_kwargs(
        cls,
        options: Any,
        tools: list[ChatCompletionToolParam] | None,
        conversation_id: str,
    ) -> dict[str, Any]:
        """Build the chat-completion keyword arguments (everything but messages)."""
        max_tokens = cls._coerce_int(
            options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS), DEFAULT_MAX_TOKENS
        )
        if max_tokens <= 0:
            max_tokens = DEFAULT_MAX_TOKENS

        if tools:
            tool_choice = _parse_tool_choice(
                options.get(CONF_TOOL_CHOICE, DEFAULT_TOOL_CHOICE)
            )
            parallel_tool_calls = options.get(
                CONF_PARALLEL_TOOL_CALLS, DEFAULT_PARALLEL_TOOL_CALLS
            )
        else:
            # These parameters only apply when tools are offered; sending them
            # (e.g. tool_choice="required") without tools risks an API error.
            tool_choice = NOT_GIVEN
            parallel_tool_calls = NOT_GIVEN

        response_format = options.get(CONF_RESPONSE_FORMAT, DEFAULT_RESPONSE_FORMAT)
        stop_sequences = _parse_stop_sequences(options.get(CONF_STOP))

        return {
            "model": options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL),
            "tools": tools or NOT_GIVEN,
            "tool_choice": tool_choice,
            "parallel_tool_calls": parallel_tool_calls,
            "max_tokens": max_tokens,
            "top_p": options.get(CONF_TOP_P, DEFAULT_TOP_P),
            "temperature": options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE),
            "frequency_penalty": options.get(
                CONF_FREQUENCY_PENALTY, DEFAULT_FREQUENCY_PENALTY
            ),
            "presence_penalty": options.get(
                CONF_PRESENCE_PENALTY, DEFAULT_PRESENCE_PENALTY
            ),
            "seed": options.get(CONF_SEED, NOT_GIVEN),
            "stop": stop_sequences or NOT_GIVEN,
            "response_format": (
                {"type": "json_object"}
                if response_format == "json_object"
                else NOT_GIVEN
            ),
            "user": conversation_id,
        }

    @staticmethod
    def _convert_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
        """Turn a response message into an assistant message param for history."""
        param = ChatCompletionAssistantMessageParam(
            role=message.role,
            content=message.content,
        )
        if message.tool_calls:
            param["tool_calls"] = [
                ChatCompletionMessageToolCallParam(
                    id=tool_call.id,
                    function=Function(
                        arguments=tool_call.function.arguments,
                        name=tool_call.function.name,
                    ),
                    type=tool_call.type,
                )
                for tool_call in message.tool_calls
            ]
        return param

    @staticmethod
    def _parse_tool_arguments(raw: str | None) -> dict[str, Any]:
        """Decode a tool call's JSON arguments into a dict.

        Empty/whitespace arguments and JSON ``null`` mean "no arguments".

        Raises:
            ValueError: ``raw`` is not valid JSON (``json.JSONDecodeError`` is a
                ``ValueError``) or does not decode to a JSON object.
        """
        if raw is None or not raw.strip():
            return {}
        args = json.loads(raw)
        if args is None:
            return {}
        if not isinstance(args, dict):
            raise ValueError(
                f"Tool arguments must be a JSON object, got {type(args).__name__}"
            )
        return args

    @staticmethod
    def _error_payload(err: Exception) -> dict[str, str]:
        """Describe ``err`` as ``{"error": <type>, "error_text": <message>}``.

        ``error_text`` is omitted when the exception has no message.
        """
        payload = {"error": type(err).__name__}
        if message := str(err):
            payload["error_text"] = message
        return payload

    async def _async_run_tool_call(
        self, llm_api: llm.APIInstance, tool_call: Any
    ) -> ChatCompletionToolMessageParam:
        """Execute one model-requested tool call and wrap the result as a tool message.

        Invalid arguments from the model and HomeAssistantError / vol.Invalid
        from the tool are reported back to the model as an error payload
        instead of aborting the turn.
        """
        tool_name = tool_call.function.name
        tool_response: Any
        try:
            tool_args = self._parse_tool_arguments(tool_call.function.arguments)
        except ValueError as err:
            LOGGER.warning("Invalid arguments for tool %s: %s", tool_name, err)
            tool_response = self._error_payload(err)
        else:
            tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
            try:
                tool_response = await llm_api.async_call_tool(tool_input)
            except (HomeAssistantError, vol.Invalid) as err:
                tool_response = self._error_payload(err)
        return ChatCompletionToolMessageParam(
            role="tool",
            tool_call_id=tool_call.id,
            content=json.dumps(tool_response),
        )

    @staticmethod
    def _trim_history(
        history: list[ChatCompletionMessageParam], limit: int
    ) -> list[ChatCompletionMessageParam]:
        """Apply the context-message limit to ``history``.

        ``limit == 0`` keeps nothing, a positive limit keeps the last ``limit``
        messages, and a negative limit keeps everything.
        """
        if limit == 0:
            return []
        if limit > 0:
            history = history[-limit:]
        # A window that starts on a tool result has lost the assistant message
        # whose tool_calls it answers. The chat-completions API requires tool
        # messages to follow that assistant message, so drop such orphans.
        start = 0
        while start < len(history) and history[start]["role"] == "tool":
            start += 1
        return history[start:]