"""NLLB translation provider implementation.""" import logging from typing import Dict, List, Optional from transformers import AutoTokenizer, AutoModelForSeq2SeqLM from ..base.translation_provider_base import TranslationProviderBase from ...domain.exceptions import TranslationFailedException logger = logging.getLogger(__name__) class NLLBTranslationProvider(TranslationProviderBase): """NLLB-200-3.3B translation provider implementation.""" # NLLB language code mappings LANGUAGE_MAPPINGS = { 'en': 'eng_Latn', 'zh': 'zho_Hans', 'zh-cn': 'zho_Hans', 'zh-tw': 'zho_Hant', 'es': 'spa_Latn', 'fr': 'fra_Latn', 'de': 'deu_Latn', 'ja': 'jpn_Jpan', 'ko': 'kor_Hang', 'ar': 'arb_Arab', 'hi': 'hin_Deva', 'pt': 'por_Latn', 'ru': 'rus_Cyrl', 'it': 'ita_Latn', 'nl': 'nld_Latn', 'pl': 'pol_Latn', 'tr': 'tur_Latn', 'sv': 'swe_Latn', 'da': 'dan_Latn', 'no': 'nor_Latn', 'fi': 'fin_Latn', 'el': 'ell_Grek', 'he': 'heb_Hebr', 'th': 'tha_Thai', 'vi': 'vie_Latn', 'id': 'ind_Latn', 'ms': 'zsm_Latn', 'tl': 'tgl_Latn', 'uk': 'ukr_Cyrl', 'cs': 'ces_Latn', 'sk': 'slk_Latn', 'hu': 'hun_Latn', 'ro': 'ron_Latn', 'bg': 'bul_Cyrl', 'hr': 'hrv_Latn', 'sr': 'srp_Cyrl', 'sl': 'slv_Latn', 'et': 'est_Latn', 'lv': 'lvs_Latn', 'lt': 'lit_Latn', 'mt': 'mlt_Latn', 'ga': 'gle_Latn', 'cy': 'cym_Latn', 'is': 'isl_Latn', 'mk': 'mkd_Cyrl', 'sq': 'sqi_Latn', 'eu': 'eus_Latn', 'ca': 'cat_Latn', 'gl': 'glg_Latn', 'ast': 'ast_Latn', 'oc': 'oci_Latn', 'br': 'bre_Latn', 'co': 'cos_Latn', 'sc': 'srd_Latn', 'rm': 'roh_Latn', 'fur': 'fur_Latn', 'lij': 'lij_Latn', 'vec': 'vec_Latn', 'pms': 'pms_Latn', 'lmo': 'lmo_Latn', 'nap': 'nap_Latn', 'scn': 'scn_Latn', 'wa': 'wln_Latn', 'frp': 'frp_Latn', 'gsw': 'gsw_Latn', 'bar': 'bar_Latn', 'ksh': 'ksh_Latn', 'lb': 'ltz_Latn', 'li': 'lim_Latn', 'nds': 'nds_Latn', 'pdc': 'pdc_Latn', 'sli': 'sli_Latn', 'vmf': 'vmf_Latn', 'yi': 'yid_Hebr', 'af': 'afr_Latn', 'zu': 'zul_Latn', 'xh': 'xho_Latn', 'st': 'sot_Latn', 'tn': 'tsn_Latn', 'ss': 'ssw_Latn', 'nr': 'nbl_Latn', 've': 'ven_Latn', 'ts': 'tso_Latn', 'sw': 'swh_Latn', 'rw': 'kin_Latn', 'rn': 'run_Latn', 'ny': 'nya_Latn', 'sn': 'sna_Latn', 'yo': 'yor_Latn', 'ig': 'ibo_Latn', 'ha': 'hau_Latn', 'ff': 'fuv_Latn', 'wo': 'wol_Latn', 'bm': 'bam_Latn', 'dyu': 'dyu_Latn', 'ee': 'ewe_Latn', 'tw': 'twi_Latn', 'ak': 'aka_Latn', 'gaa': 'gaa_Latn', 'lg': 'lug_Latn', 'luo': 'luo_Latn', 'ki': 'kik_Latn', 'kam': 'kam_Latn', 'luy': 'luy_Latn', 'mer': 'mer_Latn', 'kln': 'kln_Latn', 'kab': 'kab_Latn', 'ber': 'ber_Latn', 'am': 'amh_Ethi', 'ti': 'tir_Ethi', 'om': 'orm_Latn', 'so': 'som_Latn', 'mg': 'plt_Latn', 'ny': 'nya_Latn', 'bem': 'bem_Latn', 'tum': 'tum_Latn', 'loz': 'loz_Latn', 'lua': 'lua_Latn', 'umb': 'umb_Latn', 'kmb': 'kmb_Latn', 'kg': 'kon_Latn', 'ln': 'lin_Latn', 'sg': 'sag_Latn', 'fon': 'fon_Latn', 'mos': 'mos_Latn', 'dga': 'dga_Latn', 'kbp': 'kbp_Latn', 'nus': 'nus_Latn', 'din': 'din_Latn', 'luo': 'luo_Latn', 'ach': 'ach_Latn', 'teo': 'teo_Latn', 'mdt': 'mdt_Latn', 'knc': 'knc_Latn', 'fuv': 'fuv_Latn', 'kr': 'kau_Latn', 'dje': 'dje_Latn', 'son': 'son_Latn', 'tmh': 'tmh_Latn', 'taq': 'taq_Latn', 'ttq': 'ttq_Latn', 'thv': 'thv_Latn', 'taq': 'taq_Tfng', 'shi': 'shi_Tfng', 'tzm': 'tzm_Tfng', 'rif': 'rif_Latn', 'kab': 'kab_Latn', 'shy': 'shy_Latn', 'ber': 'ber_Latn', 'acm': 'acm_Arab', 'aeb': 'aeb_Arab', 'ajp': 'ajp_Arab', 'apc': 'apc_Arab', 'ars': 'ars_Arab', 'ary': 'ary_Arab', 'arz': 'arz_Arab', 'auz': 'auz_Arab', 'avl': 'avl_Arab', 'ayh': 'ayh_Arab', 'ayn': 'ayn_Arab', 'ayp': 'ayp_Arab', 'bbz': 'bbz_Arab', 'pga': 'pga_Arab', 'shu': 'shu_Arab', 'ssh': 'ssh_Arab', 'fa': 
        'tg': 'tgk_Cyrl', 'ps': 'pbt_Arab', 'ur': 'urd_Arab',
        'sd': 'snd_Arab', 'ks': 'kas_Arab', 'dv': 'div_Thaa',
        'ne': 'npi_Deva', 'si': 'sin_Sinh', 'my': 'mya_Mymr',
        'km': 'khm_Khmr', 'lo': 'lao_Laoo', 'ka': 'kat_Geor',
        'hy': 'hye_Armn', 'az': 'azj_Latn', 'kk': 'kaz_Cyrl',
        'ky': 'kir_Cyrl', 'uz': 'uzn_Latn', 'tk': 'tuk_Latn',
        'mn': 'khk_Cyrl', 'bo': 'bod_Tibt', 'dz': 'dzo_Tibt',
        'ug': 'uig_Arab', 'tt': 'tat_Cyrl', 'ba': 'bak_Cyrl',
        'cv': 'chv_Cyrl', 'sah': 'sah_Cyrl', 'tyv': 'tyv_Cyrl',
        'kjh': 'kjh_Cyrl', 'alt': 'alt_Cyrl', 'krc': 'krc_Cyrl',
        'kum': 'kum_Cyrl', 'nog': 'nog_Cyrl', 'kaa': 'kaa_Cyrl',
        'crh': 'crh_Latn', 'gag': 'gag_Latn', 'ku': 'ckb_Arab',
        'lrc': 'lrc_Arab', 'mzn': 'mzn_Arab', 'glk': 'glk_Arab',
        'prs': 'prs_Arab', 'haz': 'haz_Arab', 'bal': 'bal_Arab',
        'bcc': 'bcc_Arab', 'bgp': 'bgp_Arab', 'bqi': 'bqi_Arab',
        'ckb': 'ckb_Arab', 'diq': 'diq_Latn', 'hac': 'hac_Arab',
        'kur': 'kmr_Latn', 'lki': 'lki_Arab', 'pnb': 'pnb_Arab',
        'skr': 'skr_Arab', 'wne': 'wne_Arab',
        # Caucasus and Jewish diaspora languages
        'xmf': 'xmf_Geor', 'xcl': 'xcl_Armn', 'lad': 'lad_Hebr',
        # South Asian languages
        'ml': 'mal_Mlym', 'kn': 'kan_Knda', 'te': 'tel_Telu',
        'ta': 'tam_Taml', 'or': 'ory_Orya', 'as': 'asm_Beng',
        'bn': 'ben_Beng', 'gu': 'guj_Gujr', 'pa': 'pan_Guru',
        'mr': 'mar_Deva', 'sa': 'san_Deva', 'mai': 'mai_Deva',
        'bho': 'bho_Deva', 'mag': 'mag_Deva', 'sck': 'sck_Deva',
        'new': 'new_Deva', 'bpy': 'bpy_Beng', 'ctg': 'ctg_Beng',
        'rkt': 'rkt_Beng', 'syl': 'syl_Beng', 'sat': 'sat_Olck',
        'kha': 'kha_Latn', 'grt': 'grt_Beng', 'lus': 'lus_Latn',
        'mni': 'mni_Beng', 'kok': 'kok_Deva', 'gom': 'gom_Deva',
        'doi': 'doi_Deva', 'brh': 'brh_Arab', 'hnd': 'hnd_Arab',
        'lah': 'lah_Arab', 'pst': 'pst_Arab',
        # Southeast Asian languages
        'shn': 'shn_Mymr', 'mnw': 'mnw_Mymr', 'kac': 'kac_Latn',
        'cjm': 'cjm_Arab', 'bjn': 'bjn_Latn', 'bug': 'bug_Latn',
        'jv': 'jav_Latn', 'mad': 'mad_Latn', 'min': 'min_Latn',
        'su': 'sun_Latn', 'ban': 'ban_Latn', 'bbc': 'bbc_Latn',
        'btk': 'btk_Latn', 'gor': 'gor_Latn', 'ilo': 'ilo_Latn',
        'pag': 'pag_Latn', 'war': 'war_Latn', 'hil': 'hil_Latn',
        'bcl': 'bcl_Latn', 'pam': 'pam_Latn', 'ceb': 'ceb_Latn',
        'akl': 'akl_Latn', 'bik': 'bik_Latn', 'cbk': 'cbk_Latn',
        'krj': 'krj_Latn', 'tsg': 'tsg_Latn',
        # Chinese varieties and East Asian languages
        'yue': 'yue_Hant', 'wuu': 'wuu_Hans', 'hsn': 'hsn_Hans',
        'nan': 'nan_Hant', 'hak': 'hak_Hant', 'gan': 'gan_Hans',
        'cdo': 'cdo_Hant', 'lzh': 'lzh_Hans', 'ain': 'ain_Kana',
        'ryu': 'ryu_Kana',
        # Constructed languages
        'eo': 'epo_Latn', 'ia': 'ina_Latn', 'ie': 'ile_Latn',
        'io': 'ido_Latn', 'vo': 'vol_Latn', 'nov': 'nov_Latn',
        'lfn': 'lfn_Latn', 'jbo': 'jbo_Latn', 'tlh': 'tlh_Latn',
        # Pacific languages
        'na': 'nau_Latn', 'ch': 'cha_Latn', 'mh': 'mah_Latn',
        'gil': 'gil_Latn', 'kos': 'kos_Latn', 'pon': 'pon_Latn',
        'yap': 'yap_Latn', 'chk': 'chk_Latn', 'uli': 'uli_Latn',
        'wol': 'wol_Latn', 'pau': 'pau_Latn', 'sm': 'smo_Latn',
        'to': 'ton_Latn', 'fj': 'fij_Latn', 'ty': 'tah_Latn',
        'mi': 'mri_Latn', 'haw': 'haw_Latn', 'rap': 'rap_Latn',
        'tvl': 'tvl_Latn', 'niu': 'niu_Latn',
        'tkl': 'tkl_Latn', 'bi': 'bis_Latn', 'ho': 'hmo_Latn',
        'kj': 'kua_Latn', 'nd': 'nde_Latn'
    }

    def __init__(self, model_name: str = "facebook/nllb-200-3.3B",
                 max_chunk_length: int = 1000):
        """
        Initialize the NLLB translation provider.

        Args:
            model_name: The NLLB model name to use
            max_chunk_length: Maximum length for text chunks
        """
        # Build the supported-languages mapping. For simplicity, every
        # language is assumed to translate to every other language; in
        # practice you might want to restrict this to specific pairs.
        supported_languages = {}
        for lang_code in self.LANGUAGE_MAPPINGS:
            supported_languages[lang_code] = [
                target for target in self.LANGUAGE_MAPPINGS
                if target != lang_code
            ]

        super().__init__(
            provider_name="NLLB-200-3.3B",
            supported_languages=supported_languages
        )

        self.model_name = model_name
        self.max_chunk_length = max_chunk_length
        self._tokenizer: Optional[AutoTokenizer] = None
        self._model: Optional[AutoModelForSeq2SeqLM] = None
        self._model_loaded = False
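
    # Illustrative sketch, not part of the original provider: one way
    # max_chunk_length could drive chunking if TranslationProviderBase does
    # not already do so. The helper name _split_into_chunks and the
    # sentence-splitting heuristic are assumptions, not the base class API.
    def _split_into_chunks(self, text: str) -> List[str]:
        """Greedily pack sentences into chunks of at most max_chunk_length chars."""
        import re  # local import keeps this illustrative helper self-contained

        sentences = re.split(r'(?<=[.!?])\s+', text)
        chunks: List[str] = []
        current = ""
        for sentence in sentences:
            # Start a new chunk when adding this sentence would overflow it
            if current and len(current) + len(sentence) + 1 > self.max_chunk_length:
                chunks.append(current)
                current = sentence
            else:
                current = f"{current} {sentence}".strip()
        if current:
            chunks.append(current)
        return chunks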
    def _translate_chunk(self, text: str, source_language: str,
                         target_language: str) -> str:
        """
        Translate a single chunk of text using the NLLB model.

        Args:
            text: The text chunk to translate
            source_language: Source language code
            target_language: Target language code

        Returns:
            str: The translated text chunk
        """
        try:
            # Ensure model is loaded
            self._ensure_model_loaded()

            # Map language codes to NLLB format
            source_nllb = self._map_language_code(source_language)
            target_nllb = self._map_language_code(target_language)

            logger.info(f"Translating chunk from {source_nllb} to {target_nllb}")

            # Point the tokenizer at the actual source language; otherwise it
            # keeps whatever src_lang it was loaded with (eng_Latn by default)
            self._tokenizer.src_lang = source_nllb

            # Tokenize the input
            inputs = self._tokenizer(
                text,
                return_tensors="pt",
                max_length=1024,
                truncation=True
            )

            # Generate the translation, forcing the decoder to start with the
            # target-language token
            outputs = self._model.generate(
                **inputs,
                forced_bos_token_id=self._tokenizer.convert_tokens_to_ids(target_nllb),
                max_new_tokens=1024,
                num_beams=4,
                early_stopping=True
            )

            # Decode and post-process the translation
            translated = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
            translated = self._postprocess_text(translated)

            logger.info(
                f"Chunk translation completed: {len(text)} -> {len(translated)} chars"
            )
            return translated

        except Exception as e:
            # Expected to re-raise as a provider-specific exception
            self._handle_provider_error(e, "chunk translation")

    def _ensure_model_loaded(self) -> None:
        """Ensure the NLLB model and tokenizer are loaded."""
        if self._model_loaded:
            return

        try:
            logger.info(f"Loading NLLB model: {self.model_name}")

            # Load tokenizer with a default source language
            self._tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                src_lang="eng_Latn"
            )

            # Load model
            self._model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)

            self._model_loaded = True
            logger.info("NLLB model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load NLLB model: {str(e)}")
            raise TranslationFailedException(
                f"Failed to load NLLB model: {str(e)}"
            ) from e

    def _map_language_code(self, language_code: str) -> str:
        """
        Map a standard language code to NLLB format.

        Args:
            language_code: Standard language code (e.g., 'en', 'zh')

        Returns:
            str: NLLB language code (e.g., 'eng_Latn', 'zho_Hans')
        """
        # Normalize language code to lowercase
        normalized_code = language_code.lower()

        # Check direct mapping
        if normalized_code in self.LANGUAGE_MAPPINGS:
            return self.LANGUAGE_MAPPINGS[normalized_code]

        # Handle Chinese variants not covered by the direct mapping
        if normalized_code.startswith('zh'):
            if ('tw' in normalized_code or 'hant' in normalized_code
                    or 'traditional' in normalized_code):
                return 'zho_Hant'
            return 'zho_Hans'

        # Fall back to the base language for region-tagged codes (e.g. 'fr-ca')
        base_code = normalized_code.split('-')[0]
        if base_code in self.LANGUAGE_MAPPINGS:
            return self.LANGUAGE_MAPPINGS[base_code]

        # Default fallback for unknown codes
        logger.warning(f"Unknown language code: {language_code}, defaulting to English")
        return 'eng_Latn'

    def is_available(self) -> bool:
        """
        Check whether the NLLB translation provider is available.

        Returns:
            bool: True if the provider is available, False otherwise
        """
        try:
            # Verify that the required dependencies can be imported
            import torch  # noqa: F401
            import transformers  # noqa: F401

            if self._model_loaded:
                return True

            # Lightweight check: loading the tokenizer is much cheaper than
            # loading the full model
            try:
                AutoTokenizer.from_pretrained(self.model_name, src_lang="eng_Latn")
                return True
            except Exception as e:
                logger.warning(f"NLLB model not available: {str(e)}")
                return False

        except ImportError as e:
            logger.warning(f"NLLB dependencies not available: {str(e)}")
            return False

    def get_supported_languages(self) -> Dict[str, List[str]]:
        """
        Get supported language pairs for the NLLB provider.

        Returns:
            dict: Mapping of source languages to supported target languages
        """
        return self.supported_languages.copy()
    def get_model_info(self) -> Dict[str, str]:
        """
        Get information about the loaded model.

        Returns:
            dict: Model information
        """
        return {
            'provider': self.provider_name,
            'model_name': self.model_name,
            'model_loaded': str(self._model_loaded),
            'supported_language_count': str(len(self.LANGUAGE_MAPPINGS)),
            'max_chunk_length': str(self.max_chunk_length)
        }

    def set_model_name(self, model_name: str) -> None:
        """
        Set a different NLLB model name.

        Args:
            model_name: The new model name to use
        """
        if model_name != self.model_name:
            self.model_name = model_name
            # Invalidate the cached model so the new one is loaded lazily
            self._model_loaded = False
            self._tokenizer = None
            self._model = None
            logger.info(f"Model name changed to: {model_name}")

    def clear_model_cache(self) -> None:
        """Clear the loaded model from memory."""
        if self._model_loaded:
            self._tokenizer = None
            self._model = None
            self._model_loaded = False
            logger.info("NLLB model cache cleared")
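
# ---------------------------------------------------------------------------
# Minimal usage sketch (assumption: run via `python -m ...` from a context
# where the relative imports above resolve). It exercises only code defined
# in this module; note that the first run downloads the multi-gigabyte
# facebook/nllb-200-3.3B checkpoint.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    provider = NLLBTranslationProvider()
    if provider.is_available():
        # _translate_chunk is internal; callers would normally go through the
        # base class's public translation entry point
        print(provider._translate_chunk("Hello, world!", "en", "fr"))
    else:
        print("NLLB provider is not available on this machine.")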