Custom providers

Connect new backends and devices to existing toolchains.

You don’t always need to create a new toolchain from scratch. If the existing compiler, profiler, or invoker already handles the model format you need, you can register a custom provider that adds support for running on a new backend or device type.

Overview

A provider is a function registered on an existing component that handles execution for a specific provider type. For example, the built-in ONNXRuntimeCompiler has providers for qai-hub and embedl-onnxruntime. You can register additional providers for your own backends.

This is useful when you:

  • Have a custom compilation or profiling tool that works with a supported model format (TFLite, ONNX Runtime, TensorRT).
  • Want to connect existing toolchains to a new cloud service or device type.
  • Need to customize how models are transferred or executed on your hardware.
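The registration-and-dispatch mechanism can be pictured with a minimal, self-contained sketch. Note this is an illustration of the pattern only, not embedl-hub's internal implementation; `MiniComponent` and its methods are invented for this example:

```python
from typing import Callable, Dict


class MiniComponent:
    """Toy component illustrating string-keyed provider registration."""

    _providers: Dict[str, Callable] = {}

    @classmethod
    def provider(cls, provider_type: str):
        """Decorator that registers a function under the given provider type."""
        def register(fn: Callable) -> Callable:
            cls._providers[provider_type] = fn
            return fn
        return register

    @classmethod
    def run(cls, provider_type: str, *args, **kwargs):
        """Dispatch to whichever provider was registered for this type."""
        return cls._providers[provider_type](*args, **kwargs)


@MiniComponent.provider("my_custom_backend")
def _compile(model: str) -> str:
    return f"compiled {model} on my_custom_backend"


print(MiniComponent.run("my_custom_backend", "model.onnx"))
```

The component owns the registry, so registering a provider is a side effect of importing the module that defines it.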

Registering a provider on an existing component

Use the @Component.provider decorator on the existing component class:

from pathlib import Path

from embedl_hub.core import HubContext
from embedl_hub.core.compile import ONNXRuntimeCompiler


@ONNXRuntimeCompiler.provider("my_custom_backend")
def _compile_on_my_backend(
    ctx: HubContext,
    onnx_path: Path,
    *,
    device: str | None = None,
    input_shape: tuple[int, ...] | None = None,
    calibration_data=None,
    calibration_method=None,
    per_channel: bool = False,
    quantize_io: bool = False,
):
    """Compile ONNX models using my custom backend."""
    # Your compilation logic here
    ...

Creating a device configuration

If your provider needs custom settings (paths, flags, etc.), create a ProviderConfig subclass:

from dataclasses import dataclass, field

from embedl_hub.core.device import ProviderConfig


@dataclass(frozen=True)
class MyBackendConfig(ProviderConfig):
    """Configuration for my custom backend."""
    backend_path: str = "/usr/local/bin/my-backend"
    cli_args: tuple[str, ...] = field(default_factory=tuple)
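Because the configuration is a frozen dataclass, instances are immutable, and dataclasses.replace is a convenient way to derive variants. The sketch below uses a stand-in class defined inline (rather than the real ProviderConfig base) so it runs on its own:

```python
from dataclasses import dataclass, field, replace


@dataclass(frozen=True)
class MyBackendConfig:
    """Stand-in for the ProviderConfig subclass above."""
    backend_path: str = "/usr/local/bin/my-backend"
    cli_args: tuple[str, ...] = field(default_factory=tuple)


default = MyBackendConfig()
# replace() builds a new instance; the original is never mutated
verbose = replace(default, cli_args=("--verbose",))
```

Immutability also makes configs safe to share between devices and to use as defaults without the usual mutable-default pitfalls.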

Then create devices with this config:

from embedl_hub.core.device import Device, DeviceSpec, SSHCommandRunner, SSHConfig


def get_my_device(
    host: str,
    username: str,
    *,
    name: str = "main",
    config: MyBackendConfig = MyBackendConfig(),
) -> Device:
    runner = SSHCommandRunner(SSHConfig(host=host, username=username))
    return Device(
        name=name,
        runner=runner,
        spec=DeviceSpec(device_name="My Custom Device"),
        provider_type="my_custom_backend",
        provider_config=config,
    )

Accessing the configuration inside a provider

Inside your provider function, retrieve the typed configuration from the device:

@ONNXRuntimeCompiler.provider("my_custom_backend")
def _compile_on_my_backend(ctx, onnx_path, *, device=None, **kwargs):
    dev = ctx.devices[device or "main"]
    cfg = dev.get_provider_config(MyBackendConfig, ONNXRuntimeCompiler)
    backend_path = cfg.backend_path if cfg else "/usr/local/bin/my-backend"
    # Use the runner to execute commands on the device
    result = dev.runner.run([backend_path, str(onnx_path)])
    ...

Per-component configuration overrides

You can provide different configurations for different components on the same device. For example, using a different backend path for compilation vs. profiling:

from embedl_hub.core.compile import ONNXRuntimeCompiler
from embedl_hub.core.profile import ONNXRuntimeProfiler

overrides = {
    ONNXRuntimeProfiler: MyBackendConfig(
        backend_path="/usr/bin/profile-tool",
    ),
}
device = Device(
    name="my-device",
    runner=runner,
    spec=DeviceSpec(device_name="My Device"),
    provider_type="my_custom_backend",
    provider_config=MyBackendConfig(backend_path="/usr/bin/compile-tool"),
    provider_config_overrides=overrides,
)
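The resolution order amounts to a simple fallback: the per-component override wins if present, otherwise the device-wide default applies. A toy sketch of that lookup (illustrative only; the library's actual resolution code is not shown here):

```python
class Compiler:
    """Stand-in for a component class used as an override key."""


class Profiler:
    """Stand-in for a second component class."""


def resolve_config(component_cls, default, overrides):
    """Return the component-specific override if present, else the default."""
    return overrides.get(component_cls, default)


overrides = {Profiler: "profile-config"}

# The profiler gets its override; the compiler falls back to the default.
print(resolve_config(Profiler, "default-config", overrides))
print(resolve_config(Compiler, "default-config", overrides))
```

Keying the override dict on the component class (rather than a string) keeps lookups unambiguous when several components share a provider type.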

Provider type strings

Provider types are plain strings. The built-in ProviderType enum provides well-known values (LOCAL, QAI_HUB, AWS, EMBEDL_ONNXRUNTIME, TRTEXEC), but any string works:

from embedl_hub.core.component import ProviderType

# Using a built-in type
@MyCompiler.provider(ProviderType.LOCAL)
def _local(ctx, model_path, **kwargs): ...

# Using a custom string
@MyCompiler.provider("my_custom_backend")
def _custom(ctx, model_path, **kwargs): ...

Example: Adding SSH support to an existing compiler

Here’s a full example that adds a new SSH-based provider to the existing TFLiteCompiler:

from pathlib import Path

from embedl_hub.core import HubContext
from embedl_hub.core.compile import TFLiteCompiler


@TFLiteCompiler.provider("my_tflite_device")
def _compile_tflite_ssh(
    ctx: HubContext,
    onnx_path: Path,
    *,
    device: str | None = None,
    input_shape: tuple[int, ...] | None = None,
    calibration_data=None,
    quantize_io: bool = False,
):
    dev = ctx.devices[device or "main"]
    # Transfer model to device
    remote_path = "/tmp/" + onnx_path.name
    dev.runner.put(onnx_path, remote_path)
    # Run compilation on the device
    output_remote = "/tmp/" + onnx_path.stem + ".tflite"
    dev.runner.run(["my-tflite-compiler", remote_path, "-o", output_remote])
    # Fetch result
    local_output = ctx.artifact_dir / (onnx_path.stem + ".tflite")
    dev.runner.get(output_remote, local_output)
    # Log artifacts when a tracking client is attached
    if ctx.client is not None:
        ctx.client.log_artifact(onnx_path, name="input")
        ctx.client.log_artifact(local_output, name="path")
    # Return the compiled model unconditionally (not only when a client exists)
    from embedl_hub.core.compile import TFLiteCompiledModel
    return TFLiteCompiledModel.from_current_run(ctx)

Next steps