mirror of
https://github.com/sudoxreboot/groqd
synced 2026-04-14 03:26:35 +00:00
Initial groqd integration
This commit is contained in:
commit
9ff2868b02
8 changed files with 782 additions and 0 deletions
19
README.md
Normal file
19
README.md
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
# groqd
|
||||
|
||||
Home Assistant custom integration for Groq Cloud with full, configurable LLM options.
|
||||
|
||||
## Features
|
||||
- Multiple instances (different models, personalities, settings)
|
||||
- Tool calling with Home Assistant LLM API integration
|
||||
- Configurable prompts, context limits, and generation parameters
|
||||
- Options flow for post-setup edits
|
||||
|
||||
## Install (HACS)
|
||||
1. Add this repository as a custom repository in HACS (type: Integration).
|
||||
2. Install `groqd`.
|
||||
3. Restart Home Assistant.
|
||||
4. Add the integration from **Settings → Devices & Services**.
|
||||
|
||||
## Notes
|
||||
- Bring your own Groq API key.
|
||||
- TTS is provided by Home Assistant and other integrations.
|
||||
39
custom_components/groqd/__init__.py
Normal file
39
custom_components/groqd/__init__.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
"""groqd integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import groq
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_API_KEY, Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .const import DOMAIN
|
||||
|
||||
PLATFORMS: list[Platform] = [Platform.CONVERSATION]
|
||||
|
||||
|
||||
@dataclass
|
||||
class GroqdRuntimeData:
|
||||
"""Runtime data for groqd."""
|
||||
|
||||
client: groq.AsyncClient
|
||||
|
||||
|
||||
type GroqdConfigEntry = ConfigEntry[GroqdRuntimeData]
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: GroqdConfigEntry) -> bool:
|
||||
"""Set up groqd from a config entry."""
|
||||
api_key = entry.data[CONF_API_KEY]
|
||||
entry.runtime_data = GroqdRuntimeData(client=groq.AsyncClient(api_key=api_key))
|
||||
|
||||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
||||
return True
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: GroqdConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
||||
258
custom_components/groqd/config_flow.py
Normal file
258
custom_components/groqd/config_flow.py
Normal file
|
|
@ -0,0 +1,258 @@
|
|||
"""Config flow for groqd."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import exceptions
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigFlow, ConfigFlowResult, OptionsFlow
|
||||
from homeassistant.const import CONF_API_KEY, CONF_NAME, CONF_LLM_HASS_API
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.selector import (
|
||||
NumberSelector,
|
||||
NumberSelectorConfig,
|
||||
SelectOptionDict,
|
||||
SelectSelector,
|
||||
SelectSelectorConfig,
|
||||
TemplateSelector,
|
||||
)
|
||||
|
||||
from .const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_CONTEXT_MESSAGES,
|
||||
CONF_FREQUENCY_PENALTY,
|
||||
CONF_PARALLEL_TOOL_CALLS,
|
||||
CONF_PRESENCE_PENALTY,
|
||||
CONF_PROMPT,
|
||||
CONF_RESPONSE_FORMAT,
|
||||
CONF_SEED,
|
||||
CONF_STOP,
|
||||
CONF_TEMPERATURE,
|
||||
CONF_TOOL_CHOICE,
|
||||
CONF_TOP_P,
|
||||
DEFAULT_CHAT_MODEL,
|
||||
DEFAULT_CONTEXT_MESSAGES,
|
||||
DEFAULT_FREQUENCY_PENALTY,
|
||||
DEFAULT_NAME,
|
||||
DEFAULT_PARALLEL_TOOL_CALLS,
|
||||
DEFAULT_PRESENCE_PENALTY,
|
||||
DEFAULT_RESPONSE_FORMAT,
|
||||
DEFAULT_TEMPERATURE,
|
||||
DEFAULT_TOOL_CHOICE,
|
||||
DEFAULT_TOP_P,
|
||||
DOMAIN,
|
||||
)
|
||||
from .const import DEFAULT_MAX_TOKENS, CONF_MAX_TOKENS
|
||||
from .const import LOGGER
|
||||
|
||||
STEP_USER_DATA_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_NAME, default=DEFAULT_NAME): cv.string,
|
||||
vol.Required(CONF_API_KEY): cv.string,
|
||||
vol.Required(CONF_CHAT_MODEL, default=DEFAULT_CHAT_MODEL): cv.string,
|
||||
vol.Optional(
|
||||
CONF_PROMPT,
|
||||
default=llm.DEFAULT_INSTRUCTIONS_PROMPT,
|
||||
): TemplateSelector(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def _fetch_models(api_key: str, hass: HomeAssistant) -> list[str]:
|
||||
session = async_get_clientsession(hass)
|
||||
async with session.get(
|
||||
"https://api.groq.com/openai/v1/models",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
timeout=20,
|
||||
) as response:
|
||||
if response.status == 401:
|
||||
raise InvalidAPIKey
|
||||
if response.status == 403:
|
||||
raise UnauthorizedError
|
||||
if response.status != 200:
|
||||
raise UnknownError
|
||||
data = await response.json()
|
||||
return [item.get("id") for item in data.get("data", []) if item.get("id")]
|
||||
|
||||
|
||||
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
|
||||
"""Validate user input."""
|
||||
obscured_api_key = data.get(CONF_API_KEY)
|
||||
data[CONF_API_KEY] = "<api_key>"
|
||||
LOGGER.debug("User validation got: %s", data)
|
||||
data[CONF_API_KEY] = obscured_api_key
|
||||
|
||||
models = await _fetch_models(data.get(CONF_API_KEY), hass)
|
||||
if not models:
|
||||
raise UnknownError
|
||||
|
||||
model = data.get(CONF_CHAT_MODEL)
|
||||
if model not in models:
|
||||
LOGGER.warning("Model not found: %s. Available: %s", model, models)
|
||||
raise ModelNotFound
|
||||
|
||||
|
||||
class GroqdConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
"""Handle a config flow for groqd."""
|
||||
|
||||
VERSION = 1
|
||||
MINOR_VERSION = 0
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
if user_input is None:
|
||||
return self.async_show_form(step_id="user", data_schema=STEP_USER_DATA_SCHEMA)
|
||||
|
||||
errors: dict[str, str] = {}
|
||||
try:
|
||||
await validate_input(self.hass, user_input)
|
||||
except InvalidAPIKey:
|
||||
errors["base"] = "invalid_auth"
|
||||
except UnauthorizedError:
|
||||
errors["base"] = "unauthorized"
|
||||
except ModelNotFound:
|
||||
errors[CONF_CHAT_MODEL] = "model_not_found"
|
||||
except Exception:
|
||||
LOGGER.exception("Unexpected exception")
|
||||
errors["base"] = "unknown"
|
||||
else:
|
||||
return self.async_create_entry(title=user_input[CONF_NAME], data=user_input)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def async_get_options_flow(config_entry: ConfigEntry) -> OptionsFlow:
|
||||
return GroqdOptionsFlow(config_entry)
|
||||
|
||||
|
||||
class GroqdOptionsFlow(OptionsFlow):
|
||||
"""Options flow for groqd."""
|
||||
|
||||
def __init__(self, config_entry: ConfigEntry) -> None:
|
||||
self.config_entry = config_entry
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
options = dict(self.config_entry.options)
|
||||
|
||||
if user_input is not None:
|
||||
api_key = user_input.pop(CONF_API_KEY, "")
|
||||
if api_key:
|
||||
self.hass.config_entries.async_update_entry(
|
||||
self.config_entry,
|
||||
data={**self.config_entry.data, CONF_API_KEY: api_key},
|
||||
)
|
||||
return self.async_create_entry(title="", data=user_input)
|
||||
|
||||
hass_apis: list[SelectOptionDict] = [
|
||||
SelectOptionDict(label="No control", value="none")
|
||||
]
|
||||
hass_apis.extend(
|
||||
SelectOptionDict(label=api.name, value=api.id)
|
||||
for api in llm.async_get_apis(self.hass)
|
||||
)
|
||||
|
||||
schema = {
|
||||
vol.Optional(CONF_API_KEY): cv.string,
|
||||
vol.Optional(
|
||||
CONF_PROMPT,
|
||||
description={"suggested_value": options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT)},
|
||||
): TemplateSelector(),
|
||||
vol.Optional(
|
||||
CONF_CHAT_MODEL,
|
||||
description={"suggested_value": options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)},
|
||||
default=options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL),
|
||||
): cv.string,
|
||||
vol.Optional(
|
||||
CONF_CONTEXT_MESSAGES,
|
||||
description={"suggested_value": options.get(CONF_CONTEXT_MESSAGES, DEFAULT_CONTEXT_MESSAGES)},
|
||||
default=options.get(CONF_CONTEXT_MESSAGES, DEFAULT_CONTEXT_MESSAGES),
|
||||
): NumberSelector(NumberSelectorConfig(min=0, max=200, step=1)),
|
||||
vol.Optional(
|
||||
CONF_MAX_TOKENS,
|
||||
description={"suggested_value": options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)},
|
||||
default=options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS),
|
||||
): NumberSelector(NumberSelectorConfig(min=1, max=8192, step=1)),
|
||||
vol.Optional(
|
||||
CONF_TEMPERATURE,
|
||||
description={"suggested_value": options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)},
|
||||
default=options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE),
|
||||
): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)),
|
||||
vol.Optional(
|
||||
CONF_TOP_P,
|
||||
description={"suggested_value": options.get(CONF_TOP_P, DEFAULT_TOP_P)},
|
||||
default=options.get(CONF_TOP_P, DEFAULT_TOP_P),
|
||||
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
||||
vol.Optional(
|
||||
CONF_FREQUENCY_PENALTY,
|
||||
description={"suggested_value": options.get(CONF_FREQUENCY_PENALTY, DEFAULT_FREQUENCY_PENALTY)},
|
||||
default=options.get(CONF_FREQUENCY_PENALTY, DEFAULT_FREQUENCY_PENALTY),
|
||||
): NumberSelector(NumberSelectorConfig(min=-2, max=2, step=0.05)),
|
||||
vol.Optional(
|
||||
CONF_PRESENCE_PENALTY,
|
||||
description={"suggested_value": options.get(CONF_PRESENCE_PENALTY, DEFAULT_PRESENCE_PENALTY)},
|
||||
default=options.get(CONF_PRESENCE_PENALTY, DEFAULT_PRESENCE_PENALTY),
|
||||
): NumberSelector(NumberSelectorConfig(min=-2, max=2, step=0.05)),
|
||||
vol.Optional(
|
||||
CONF_SEED,
|
||||
description={"suggested_value": options.get(CONF_SEED)},
|
||||
): cv.positive_int,
|
||||
vol.Optional(
|
||||
CONF_STOP,
|
||||
description={"suggested_value": options.get(CONF_STOP, "")},
|
||||
): cv.string,
|
||||
vol.Optional(
|
||||
CONF_TOOL_CHOICE,
|
||||
description={"suggested_value": options.get(CONF_TOOL_CHOICE, DEFAULT_TOOL_CHOICE)},
|
||||
default=options.get(CONF_TOOL_CHOICE, DEFAULT_TOOL_CHOICE),
|
||||
): cv.string,
|
||||
vol.Optional(
|
||||
CONF_PARALLEL_TOOL_CALLS,
|
||||
description={"suggested_value": options.get(CONF_PARALLEL_TOOL_CALLS, DEFAULT_PARALLEL_TOOL_CALLS)},
|
||||
default=options.get(CONF_PARALLEL_TOOL_CALLS, DEFAULT_PARALLEL_TOOL_CALLS),
|
||||
): bool,
|
||||
vol.Optional(
|
||||
CONF_RESPONSE_FORMAT,
|
||||
description={"suggested_value": options.get(CONF_RESPONSE_FORMAT, DEFAULT_RESPONSE_FORMAT)},
|
||||
default=options.get(CONF_RESPONSE_FORMAT, DEFAULT_RESPONSE_FORMAT),
|
||||
): SelectSelector(
|
||||
SelectSelectorConfig(
|
||||
options=[
|
||||
SelectOptionDict(label="text", value="text"),
|
||||
SelectOptionDict(label="json_object", value="json_object"),
|
||||
]
|
||||
)
|
||||
),
|
||||
vol.Optional(
|
||||
CONF_LLM_HASS_API,
|
||||
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
|
||||
default=options.get(CONF_LLM_HASS_API, "none"),
|
||||
): SelectSelector(SelectSelectorConfig(options=hass_apis)),
|
||||
}
|
||||
|
||||
return self.async_show_form(step_id="init", data_schema=vol.Schema(schema))
|
||||
|
||||
|
||||
class UnknownError(exceptions.HomeAssistantError):
|
||||
"""Unknown error."""
|
||||
|
||||
|
||||
class UnauthorizedError(exceptions.HomeAssistantError):
|
||||
"""API key valid but not authorized."""
|
||||
|
||||
|
||||
class InvalidAPIKey(exceptions.HomeAssistantError):
|
||||
"""Invalid API key error."""
|
||||
|
||||
|
||||
class ModelNotFound(exceptions.HomeAssistantError):
|
||||
"""Model not found in Groq model list."""
|
||||
35
custom_components/groqd/const.py
Normal file
35
custom_components/groqd/const.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
"""Constants for groqd."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
DOMAIN = "groqd"
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_NAME = "groqd"
|
||||
|
||||
CONF_CHAT_MODEL = "chat_model"
|
||||
CONF_PROMPT = "prompt"
|
||||
CONF_CONTEXT_MESSAGES = "context_messages"
|
||||
CONF_MAX_TOKENS = "max_tokens"
|
||||
CONF_TEMPERATURE = "temperature"
|
||||
CONF_TOP_P = "top_p"
|
||||
CONF_FREQUENCY_PENALTY = "frequency_penalty"
|
||||
CONF_PRESENCE_PENALTY = "presence_penalty"
|
||||
CONF_SEED = "seed"
|
||||
CONF_STOP = "stop_sequences"
|
||||
CONF_TOOL_CHOICE = "tool_choice"
|
||||
CONF_PARALLEL_TOOL_CALLS = "parallel_tool_calls"
|
||||
CONF_RESPONSE_FORMAT = "response_format"
|
||||
|
||||
DEFAULT_CHAT_MODEL = "meta-llama/llama-4-maverick-17b-128e-instruct"
|
||||
DEFAULT_CONTEXT_MESSAGES = 20
|
||||
DEFAULT_MAX_TOKENS = 512
|
||||
DEFAULT_TEMPERATURE = 1.0
|
||||
DEFAULT_TOP_P = 1.0
|
||||
DEFAULT_FREQUENCY_PENALTY = 0.0
|
||||
DEFAULT_PRESENCE_PENALTY = 0.0
|
||||
DEFAULT_TOOL_CHOICE = "auto"
|
||||
DEFAULT_PARALLEL_TOOL_CALLS = True
|
||||
DEFAULT_RESPONSE_FORMAT = "text"
|
||||
360
custom_components/groqd/conversation.py
Normal file
360
custom_components/groqd/conversation.py
Normal file
|
|
@ -0,0 +1,360 @@
|
|||
"""Conversation support for groqd."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import json
|
||||
from typing import Any, Literal
|
||||
|
||||
from groq._types import NOT_GIVEN
|
||||
from groq.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionMessageToolCallParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
from groq.types.chat.chat_completion_message_tool_call_param import Function
|
||||
from groq.types.shared_params import FunctionDefinition
|
||||
import groq
|
||||
import voluptuous as vol
|
||||
from voluptuous_openapi import convert
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.conversation import trace
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
||||
from homeassistant.helpers import device_registry as dr, intent, llm, template
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.util import ulid
|
||||
|
||||
from . import GroqdConfigEntry
|
||||
from .const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_CONTEXT_MESSAGES,
|
||||
CONF_FREQUENCY_PENALTY,
|
||||
CONF_MAX_TOKENS,
|
||||
CONF_PARALLEL_TOOL_CALLS,
|
||||
CONF_PRESENCE_PENALTY,
|
||||
CONF_PROMPT,
|
||||
CONF_RESPONSE_FORMAT,
|
||||
CONF_SEED,
|
||||
CONF_STOP,
|
||||
CONF_TEMPERATURE,
|
||||
CONF_TOOL_CHOICE,
|
||||
CONF_TOP_P,
|
||||
DEFAULT_CHAT_MODEL,
|
||||
DEFAULT_CONTEXT_MESSAGES,
|
||||
DEFAULT_FREQUENCY_PENALTY,
|
||||
DEFAULT_MAX_TOKENS,
|
||||
DEFAULT_PARALLEL_TOOL_CALLS,
|
||||
DEFAULT_PRESENCE_PENALTY,
|
||||
DEFAULT_RESPONSE_FORMAT,
|
||||
DEFAULT_TEMPERATURE,
|
||||
DEFAULT_TOOL_CHOICE,
|
||||
DEFAULT_TOP_P,
|
||||
DOMAIN,
|
||||
LOGGER,
|
||||
)
|
||||
|
||||
MAX_TOOL_ITERATIONS = 10
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: GroqdConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up conversation entities."""
|
||||
agent = GroqdConversationEntity(config_entry)
|
||||
async_add_entities([agent])
|
||||
|
||||
|
||||
def _format_tool(
|
||||
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
|
||||
) -> ChatCompletionToolParam:
|
||||
"""Format tool specification."""
|
||||
tool_spec = FunctionDefinition(
|
||||
name=tool.name,
|
||||
parameters=convert(tool.parameters, custom_serializer=custom_serializer),
|
||||
)
|
||||
if tool.description:
|
||||
tool_spec["description"] = tool.description
|
||||
return ChatCompletionToolParam(type="function", function=tool_spec)
|
||||
|
||||
|
||||
def _parse_tool_choice(value: str) -> Any:
|
||||
"""Parse tool choice option into Groq-compatible value."""
|
||||
value = (value or "").strip()
|
||||
if not value:
|
||||
return NOT_GIVEN
|
||||
if value in {"auto", "none", "required"}:
|
||||
return value
|
||||
if value.startswith("tool:"):
|
||||
name = value.split(":", 1)[1].strip()
|
||||
if name:
|
||||
return {"type": "function", "function": {"name": name}}
|
||||
return value
|
||||
|
||||
|
||||
def _parse_stop_sequences(value: str | None) -> list[str] | None:
|
||||
if not value:
|
||||
return None
|
||||
parts = [part.strip() for part in value.replace("\r", "").split("\n")]
|
||||
expanded = []
|
||||
for part in parts:
|
||||
expanded.extend([item.strip() for item in part.split(",")])
|
||||
return [part for part in expanded if part]
|
||||
|
||||
|
||||
class GroqdConversationEntity(
|
||||
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
||||
):
|
||||
"""groqd conversation agent."""
|
||||
|
||||
_attr_has_entity_name = True
|
||||
_attr_name = None
|
||||
|
||||
def __init__(self, entry: GroqdConfigEntry) -> None:
|
||||
self.entry = entry
|
||||
self.history: dict[str, list[ChatCompletionMessageParam]] = {}
|
||||
self._attr_unique_id = entry.entry_id
|
||||
self._attr_device_info = dr.DeviceInfo(
|
||||
identifiers={(DOMAIN, entry.entry_id)},
|
||||
name=entry.title,
|
||||
manufacturer="Groq",
|
||||
model="Groq Cloud",
|
||||
entry_type=dr.DeviceEntryType.SERVICE,
|
||||
)
|
||||
if self.entry.options.get(CONF_LLM_HASS_API):
|
||||
self._attr_supported_features = (
|
||||
conversation.ConversationEntityFeature.CONTROL
|
||||
)
|
||||
|
||||
@property
|
||||
def supported_languages(self) -> list[str] | Literal["*"]:
|
||||
return MATCH_ALL
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
await super().async_added_to_hass()
|
||||
conversation.async_set_agent(self.hass, self.entry, self)
|
||||
self.entry.async_on_unload(
|
||||
self.entry.add_update_listener(self._async_entry_update_listener)
|
||||
)
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
conversation.async_unset_agent(self.hass, self.entry)
|
||||
await super().async_will_remove_from_hass()
|
||||
|
||||
async def async_process(
|
||||
self, user_input: conversation.ConversationInput
|
||||
) -> conversation.ConversationResult:
|
||||
options = self.entry.options
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
llm_api: llm.APIInstance | None = None
|
||||
tools: list[ChatCompletionToolParam] | None = None
|
||||
user_name: str | None = None
|
||||
llm_context = llm.LLMContext(
|
||||
platform=DOMAIN,
|
||||
context=user_input.context,
|
||||
language=user_input.language,
|
||||
assistant=conversation.DOMAIN,
|
||||
device_id=user_input.device_id,
|
||||
)
|
||||
|
||||
if options.get(CONF_LLM_HASS_API) and options.get(CONF_LLM_HASS_API) != "none":
|
||||
try:
|
||||
llm_api = await llm.async_get_api(
|
||||
self.hass,
|
||||
options[CONF_LLM_HASS_API],
|
||||
llm_context,
|
||||
)
|
||||
except HomeAssistantError as err:
|
||||
LOGGER.error("Error getting LLM API: %s", err)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Error preparing LLM API: {err}",
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=user_input.conversation_id
|
||||
)
|
||||
tools = [_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools]
|
||||
|
||||
if user_input.conversation_id is None:
|
||||
conversation_id = ulid.ulid_now()
|
||||
history = []
|
||||
elif user_input.conversation_id in self.history:
|
||||
conversation_id = user_input.conversation_id
|
||||
history = self.history[conversation_id]
|
||||
else:
|
||||
try:
|
||||
ulid.ulid_to_bytes(user_input.conversation_id)
|
||||
conversation_id = ulid.ulid_now()
|
||||
except ValueError:
|
||||
conversation_id = user_input.conversation_id
|
||||
history = []
|
||||
|
||||
if (
|
||||
user_input.context
|
||||
and user_input.context.user_id
|
||||
and (user := await self.hass.auth.async_get_user(user_input.context.user_id))
|
||||
):
|
||||
user_name = user.name
|
||||
|
||||
try:
|
||||
prompt_parts = [
|
||||
template.Template(
|
||||
llm.DATE_TIME_PROMPT
|
||||
+ options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
|
||||
self.hass,
|
||||
).async_render(
|
||||
{
|
||||
"ha_name": self.hass.config.location_name,
|
||||
"user_name": user_name,
|
||||
"llm_context": llm_context,
|
||||
},
|
||||
parse_result=False,
|
||||
)
|
||||
]
|
||||
except TemplateError as err:
|
||||
LOGGER.error("Error rendering prompt: %s", err)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Template error: {err}",
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
if llm_api:
|
||||
prompt_parts.append(llm_api.api_prompt)
|
||||
|
||||
prompt = "\n".join(prompt_parts)
|
||||
|
||||
messages: list[ChatCompletionMessageParam] = [
|
||||
ChatCompletionSystemMessageParam(role="system", content=prompt),
|
||||
*history,
|
||||
ChatCompletionUserMessageParam(role="user", content=user_input.text),
|
||||
]
|
||||
|
||||
trace.async_conversation_trace_append(
|
||||
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||
{"messages": messages, "tools": llm_api.tools if llm_api else None},
|
||||
)
|
||||
|
||||
client = self.entry.runtime_data.client
|
||||
|
||||
tool_choice = _parse_tool_choice(options.get(CONF_TOOL_CHOICE, DEFAULT_TOOL_CHOICE))
|
||||
stop_sequences = _parse_stop_sequences(options.get(CONF_STOP))
|
||||
|
||||
response_format = options.get(CONF_RESPONSE_FORMAT, DEFAULT_RESPONSE_FORMAT)
|
||||
response_format_value = (
|
||||
{"type": "json_object"} if response_format == "json_object" else NOT_GIVEN
|
||||
)
|
||||
|
||||
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||
try:
|
||||
result = await client.chat.completions.create(
|
||||
model=options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL),
|
||||
messages=messages,
|
||||
tools=tools or NOT_GIVEN,
|
||||
tool_choice=tool_choice,
|
||||
parallel_tool_calls=options.get(
|
||||
CONF_PARALLEL_TOOL_CALLS, DEFAULT_PARALLEL_TOOL_CALLS
|
||||
),
|
||||
max_tokens=options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS),
|
||||
top_p=options.get(CONF_TOP_P, DEFAULT_TOP_P),
|
||||
temperature=options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE),
|
||||
frequency_penalty=options.get(
|
||||
CONF_FREQUENCY_PENALTY, DEFAULT_FREQUENCY_PENALTY
|
||||
),
|
||||
presence_penalty=options.get(
|
||||
CONF_PRESENCE_PENALTY, DEFAULT_PRESENCE_PENALTY
|
||||
),
|
||||
seed=options.get(CONF_SEED, NOT_GIVEN),
|
||||
stop=stop_sequences or NOT_GIVEN,
|
||||
response_format=response_format_value,
|
||||
user=conversation_id,
|
||||
)
|
||||
except groq.GroqError as err:
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Groq error: {err}",
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
response = result.choices[0].message
|
||||
|
||||
def message_convert(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
|
||||
tool_calls: list[ChatCompletionMessageToolCallParam] = []
|
||||
if message.tool_calls:
|
||||
tool_calls = [
|
||||
ChatCompletionMessageToolCallParam(
|
||||
id=tool_call.id,
|
||||
function=Function(
|
||||
arguments=tool_call.function.arguments,
|
||||
name=tool_call.function.name,
|
||||
),
|
||||
type=tool_call.type,
|
||||
)
|
||||
for tool_call in message.tool_calls
|
||||
]
|
||||
param = ChatCompletionAssistantMessageParam(
|
||||
role=message.role,
|
||||
content=message.content,
|
||||
)
|
||||
if tool_calls:
|
||||
param["tool_calls"] = tool_calls
|
||||
return param
|
||||
|
||||
messages.append(message_convert(response))
|
||||
tool_calls = response.tool_calls
|
||||
|
||||
if not tool_calls or not llm_api:
|
||||
break
|
||||
|
||||
for tool_call in tool_calls:
|
||||
tool_input = llm.ToolInput(
|
||||
tool_name=tool_call.function.name,
|
||||
tool_args=json.loads(tool_call.function.arguments),
|
||||
)
|
||||
try:
|
||||
tool_response = await llm_api.async_call_tool(tool_input)
|
||||
except (HomeAssistantError, vol.Invalid) as err:
|
||||
tool_response = {"error": type(err).__name__}
|
||||
if str(err):
|
||||
tool_response["error_text"] = str(err)
|
||||
|
||||
messages.append(
|
||||
ChatCompletionToolMessageParam(
|
||||
role="tool",
|
||||
tool_call_id=tool_call.id,
|
||||
content=json.dumps(tool_response),
|
||||
)
|
||||
)
|
||||
|
||||
history = messages[1:]
|
||||
limit = options.get(CONF_CONTEXT_MESSAGES, DEFAULT_CONTEXT_MESSAGES)
|
||||
if limit == 0:
|
||||
history = []
|
||||
elif limit > 0:
|
||||
history = history[-limit:]
|
||||
|
||||
self.history[conversation_id] = history
|
||||
|
||||
intent_response.async_set_speech(response.content or "")
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
async def _async_entry_update_listener(
|
||||
self, hass: HomeAssistant, entry: ConfigEntry
|
||||
) -> None:
|
||||
await hass.config_entries.async_reload(entry.entry_id)
|
||||
14
custom_components/groqd/manifest.json
Normal file
14
custom_components/groqd/manifest.json
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
{
|
||||
"domain": "groqd",
|
||||
"name": "groqd",
|
||||
"version": "0.1.0",
|
||||
"config_flow": true,
|
||||
"documentation": "https://github.com/sudoxnym/groqd",
|
||||
"issue_tracker": "https://github.com/sudoxnym/groqd/issues",
|
||||
"codeowners": ["@sudoxnym"],
|
||||
"dependencies": ["conversation"],
|
||||
"iot_class": "cloud_push",
|
||||
"requirements": [
|
||||
"groq==0.18.0"
|
||||
]
|
||||
}
|
||||
51
custom_components/groqd/translations/en.json
Normal file
51
custom_components/groqd/translations/en.json
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
{
|
||||
"config": {
|
||||
"step": {
|
||||
"user": {
|
||||
"title": "groqd",
|
||||
"data": {
|
||||
"name": "Name",
|
||||
"api_key": "API key",
|
||||
"chat_model": "Model",
|
||||
"prompt": "Personality prompt"
|
||||
},
|
||||
"data_description": {
|
||||
"prompt": "System prompt for the assistant. Supports templates."
|
||||
}
|
||||
}
|
||||
},
|
||||
"error": {
|
||||
"invalid_auth": "Invalid API key",
|
||||
"unauthorized": "API key not authorized",
|
||||
"model_not_found": "Model not found for this key",
|
||||
"unknown": "Unknown error"
|
||||
}
|
||||
},
|
||||
"options": {
|
||||
"step": {
|
||||
"init": {
|
||||
"title": "groqd options",
|
||||
"data": {
|
||||
"api_key": "Update API key (optional)",
|
||||
"prompt": "Personality prompt",
|
||||
"chat_model": "Model",
|
||||
"context_messages": "Context messages (0 = no history)",
|
||||
"max_tokens": "Max tokens",
|
||||
"temperature": "Temperature",
|
||||
"top_p": "Top P",
|
||||
"frequency_penalty": "Frequency penalty",
|
||||
"presence_penalty": "Presence penalty",
|
||||
"seed": "Seed",
|
||||
"stop_sequences": "Stop sequences (comma or newline separated)",
|
||||
"tool_choice": "Tool choice (auto/none/required/tool:<name>)",
|
||||
"parallel_tool_calls": "Parallel tool calls",
|
||||
"response_format": "Response format",
|
||||
"llm_hass_api": "Home Assistant LLM API"
|
||||
},
|
||||
"data_description": {
|
||||
"prompt": "System prompt for the assistant. Supports templates."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
6
hacs.json
Normal file
6
hacs.json
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
{
|
||||
"name": "groqd",
|
||||
"content_in_root": false,
|
||||
"domains": ["groqd"],
|
||||
"homeassistant": "2024.8.0"
|
||||
}
|
||||
Loading…
Reference in a new issue