Source code for sign_language_translator.models._utils

"""Extra utility functions for models, placed here to avoid circular imports.

This module contains utility functions that assist with creating and loading models.

Functions:
    get_model(model_code: str | Enum, *args, **kwargs):
        Get the model based on the provided model code and optional parameters.
"""

from __future__ import annotations

__all__ = [
    "get_model",
]

from typing import TYPE_CHECKING, Union

from sign_language_translator.config.assets import Assets
from sign_language_translator.config.enums import (
    ModelCodeGroups,
    ModelCodes,
    normalize_short_code,
)

if TYPE_CHECKING:
    from enum import Enum


# Download chunk size shared by every asset-backed model (1 MiB).
_DOWNLOAD_CHUNK_SIZE = 1024 * 1024

# Size names indexed by the integer that follows "pose" in a MediaPipe model code.
_MEDIAPIPE_MODEL_SIZES = ("lite", "full", "heavy")


def _download_and_load(model_class, model_code: str):
    """Download the ``models/<model_code>`` asset and load it with ``model_class.load``.

    Args:
        model_class: A class exposing a ``load(path)`` classmethod/staticmethod.
        model_code (str): The normalized model code naming the asset.

    Returns:
        Any: Whatever ``model_class.load`` returns for the downloaded asset.
    """
    asset_id = f"models/{model_code}"
    Assets.download(
        asset_id, progress_bar=True, leave=False, chunk_size=_DOWNLOAD_CHUNK_SIZE
    )
    # Assets.get_path is indexed with [0]; presumably it returns a list of
    # matching local paths — TODO confirm it is never empty after download.
    return model_class.load(Assets.get_path(asset_id)[0])


def _mediapipe_pose_model_name(model_code: str) -> str:
    """Map a MediaPipe model code to its pose landmarker ``.task`` filename.

    The integer segment right after ``"pose"`` in the dash-separated code selects
    the model size: 0 -> lite, 1 -> full, 2 -> heavy.

    Raises:
        ValueError: If the code has no valid pose version segment.
    """
    parts = model_code.split("-")
    try:
        pose_version = int(parts[parts.index("pose") + 1])
    except (ValueError, IndexError) as exc:
        raise ValueError(
            f"Could not read a pose model version from model code {model_code!r}."
        ) from exc
    if not 0 <= pose_version < len(_MEDIAPIPE_MODEL_SIZES):
        raise ValueError(
            f"Unsupported pose model version {pose_version} in model code "
            f"{model_code!r}; expected 0 to {len(_MEDIAPIPE_MODEL_SIZES) - 1}."
        )
    return f"pose_landmarker_{_MEDIAPIPE_MODEL_SIZES[pose_version]}.task"


def get_model(model_code: Union[str, Enum], *args, **kwargs):
    """Get the model based on the provided model code and optional parameters.

    See ``sign_language_translator.config.enums.ModelCodes`` (or ``slt.ModelCodes``)
    for a list of supported model codes.

    Args:
        model_code (str | Enum): The code representing the desired model. It is
            normalized with ``normalize_short_code`` before being matched.
        *args: Positional arguments forwarded to the model constructor. They are
            only used for the concatenative synthesis model and are ignored for
            all other (downloaded or MediaPipe) models.
        **kwargs: Keyword arguments, forwarded and ignored in the same way as
            ``*args``.

    Returns:
        Any: The instantiated model object, or None if the code matches no
        known model.

    Raises:
        ValueError: If a MediaPipe model code does not contain a valid pose
            model version. Exceptions raised by the model constructors or
            loaders (e.g. argument validation inside ``ConcatenativeSynthesis``)
            propagate unchanged.
    """
    model_code = normalize_short_code(model_code)

    # Model classes are imported lazily inside each branch to avoid circular
    # imports with sign_language_translator.models.
    if model_code == ModelCodes.CONCATENATIVE_SYNTHESIS.value:
        from sign_language_translator.models import ConcatenativeSynthesis

        # TODO: validate arg types
        return ConcatenativeSynthesis(*args, **kwargs)

    if model_code in ModelCodeGroups.ALL_NGRAM_LANGUAGE_MODELS.value:
        from sign_language_translator.models import NgramLanguageModel

        return _download_and_load(NgramLanguageModel, model_code)

    if model_code in ModelCodeGroups.ALL_MIXER_LANGUAGE_MODELS.value:
        from sign_language_translator.models import MixerLM

        return _download_and_load(MixerLM, model_code)

    if model_code in ModelCodeGroups.ALL_TRANSFORMER_LANGUAGE_MODELS.value:
        from sign_language_translator.models import TransformerLanguageModel

        return _download_and_load(TransformerLanguageModel, model_code)

    if model_code in ModelCodeGroups.ALL_MEDIAPIPE_EMBEDDING_MODELS.value:
        from sign_language_translator.models import MediaPipeLandmarksModel

        # TODO: the code's "hand" version segment (hand_model_name) and
        # number_of_persons are not yet passed through to the model.
        return MediaPipeLandmarksModel(
            pose_model_name=_mediapipe_pose_model_name(model_code),
        )

    if model_code in ModelCodeGroups.ALL_VECTOR_LOOKUP_MODELS.value:
        from sign_language_translator.models import VectorLookupModel

        return _download_and_load(VectorLookupModel, model_code)

    return None