fisherman611 committed
Commit ed25d6f · verified · 1 Parent(s): 4747f84

Upload 5 files

Files changed (5)
  1. .gitignore +180 -0
  2. app.py +567 -0
  3. config.json +76 -0
  4. grammar.txt +49 -0
  5. infer.py +184 -0
.gitignore ADDED
@@ -0,0 +1,180 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc
# Dataset
/data
/checkpoints
/datatest
/visualizations
/testinfer.py
app.py ADDED
@@ -0,0 +1,567 @@
import logging
from typing import Dict, Any, Tuple
import torch
import gradio as gr
from infer import ModelLoader, DEVICE, Translator
from models.statistical_mt import LanguageModel

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# Store models and tokenizers
MODELS: Dict[str, Tuple[Any, Any]] = {
    "mbart50": (None, None),
    "mt5": (None, None),
    "rbmt": (None, None),
    "smt": (None, None)
}

def initialize_models(model_types: list[str] = ["mbart50", "mt5", "rbmt", "smt"]) -> None:
    """Initialize translation models and store them in MODELS dictionary.

    Args:
        model_types: List of model types to initialize.
    """
    global MODELS
    for model_type in model_types:
        try:
            if model_type == "mbart50":
                logger.info("Loading MBart50 model...")
                MODELS["mbart50"] = ModelLoader.load_mbart50()
                logger.info(f"MBart50 model loaded on {DEVICE}")
            elif model_type == "mt5":
                logger.info("Loading MT5 model...")
                MODELS["mt5"] = ModelLoader.load_mt5()
                logger.info(f"MT5 model loaded on {DEVICE}")
            elif model_type == "rbmt":
                logger.info("Initializing RBMT...")
                from models.rule_based_mt import TransferBasedMT
                MODELS["rbmt"] = (TransferBasedMT(), None)
                logger.info("RBMT initialized")
            elif model_type == "smt":
                logger.info("Initializing SMT...")
                MODELS["smt"] = (ModelLoader.load_smt(), None)
                logger.info("SMT initialized")
        except Exception as e:
            logger.error(f"Failed to initialize {model_type}: {str(e)}")
            MODELS[model_type] = (None, None)

def translate_text(model_type: str, input_text: str) -> str:
    """Translate input text using the selected model.

    Args:
        model_type: Type of model to use ('rbmt', 'smt', 'mbart50', 'mt5').
        input_text: English text to translate.

    Returns:
        Translated text or error message.
    """
    try:
        model, tokenizer = MODELS.get(model_type, (None, None))
        if model is None:
            return f"Error: Model '{model_type}' not loaded or not supported."
        if model_type == "rbmt":
            return Translator.translate_rbmt(input_text)
        elif model_type == "smt":
            return Translator.translate_smt(input_text, model)
        elif model_type == "mbart50":
            return Translator.translate_mbart50(input_text, model, tokenizer)
        else:  # mt5
            return Translator.translate_mt5(input_text, model, tokenizer)
    except Exception as e:
        return f"Error during translation: {str(e)}"

# Initialize models before launching the app
logger.info("Starting model initialization...")
initialize_models()
logger.info("Model initialization complete.")

# Define Gradio interface
with gr.Blocks(
    theme="soft",
    title="English to Vietnamese Translator",
    css="""
    /* Root variables for consistent theming */
    :root {
        --primary-color: #2563eb;
        --primary-hover: #1d4ed8;
        --secondary-color: #64748b;
        --success-color: #10b981;
        --error-color: #ef4444;
        --warning-color: #f59e0b;
        --background-primary: #ffffff;
        --background-secondary: #f8fafc;
        --background-tertiary: #f1f5f9;
        --text-primary: #1e293b;
        --text-secondary: #64748b;
        --border-color: #e2e8f0;
        --border-radius: 12px;
        --shadow-sm: 0 1px 2px 0 rgba(0, 0, 0, 0.05);
        --shadow-md: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06);
        --shadow-lg: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05);
        --transition: all 0.2s cubic-bezier(0.4, 0, 0.2, 1);
    }

    /* Global styles */
    * {
        box-sizing: border-box;
    }

    body {
        font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
        line-height: 1.6;
        color: var(--text-primary);
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        min-height: 100vh;
    }

    /* Main container */
    .gradio-container {
        max-width: 1200px;
        margin: 0 auto;
        padding: 2rem;
    }

    /* Header styling */
    .header {
        text-align: center;
        margin-bottom: 3rem;
        padding: 2rem;
        background: var(--background-primary);
        border-radius: var(--border-radius);
        box-shadow: var(--shadow-lg);
        backdrop-filter: blur(10px);
        border: 1px solid rgba(255, 255, 255, 0.2);
    }

    .header h1 {
        font-size: 2.5rem;
        font-weight: 700;
        color: var(--primary-color);
        margin-bottom: 0.5rem;
        text-shadow: 0 2px 4px rgba(37, 99, 235, 0.2);
        position: relative;
        z-index: 1;
    }

    /* Enhanced gradient text effect for supported browsers */
    @supports (-webkit-background-clip: text) {
        .header h1 {
            background: linear-gradient(135deg, var(--primary-color), #7c3aed, #ec4899, var(--primary-color));
            background-size: 200% 200%;
            -webkit-background-clip: text;
            -webkit-text-fill-color: transparent;
            background-clip: text;
            animation: gradientShift 4s ease-in-out infinite;
        }
    }

    @keyframes gradientShift {
        0%, 100% { background-position: 0% 50%; }
        50% { background-position: 100% 50%; }
    }

    .header p {
        color: var(--text-secondary);
        font-size: 1.1rem;
        margin: 0;
    }

    /* Main content container */
    .main-container {
        background: var(--background-primary);
        border-radius: var(--border-radius);
        padding: 2.5rem;
        box-shadow: var(--shadow-lg);
        backdrop-filter: blur(10px);
        border: 1px solid rgba(255, 255, 255, 0.2);
        transition: var(--transition);
    }

    .main-container:hover {
        box-shadow: 0 20px 25px -5px rgba(0, 0, 0, 0.1), 0 10px 10px -5px rgba(0, 0, 0, 0.04);
    }

    /* Model selection styling */
    .model-section {
        margin-bottom: 2rem;
    }

    .model-label {
        font-weight: 600;
        color: var(--text-primary);
        margin-bottom: 0.5rem;
        display: block;
    }

    .gr-dropdown {
        border-radius: var(--border-radius) !important;
        border: 2px solid var(--border-color) !important;
        transition: var(--transition) !important;
        background: var(--background-primary) !important;
    }

    .gr-dropdown:focus-within {
        border-color: var(--primary-color) !important;
        box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1) !important;
    }

    .gr-dropdown .options {
        background: var(--background-primary) !important;
        border: 1px solid var(--border-color) !important;
        border-radius: var(--border-radius) !important;
        box-shadow: var(--shadow-lg) !important;
    }

    .gr-dropdown .options .item {
        padding: 0.75rem 1rem !important;
        transition: var(--transition) !important;
        border-radius: 8px !important;
        margin: 0.25rem !important;
    }

    .gr-dropdown .options .item:hover {
        background-color: var(--background-secondary) !important;
        cursor: pointer;
        transform: translateY(-1px);
    }

    .gr-dropdown .options .item.selected {
        background-color: var(--primary-color) !important;
        color: white !important;
    }

    /* Input/Output sections */
    .io-section {
        display: grid;
        grid-template-columns: 1fr 1fr;
        gap: 2rem;
        margin-bottom: 2rem;
    }

    @media (max-width: 768px) {
        .io-section {
            grid-template-columns: 1fr;
            gap: 1.5rem;
        }
    }

    .input-section, .output-section {
        background: var(--background-secondary);
        padding: 1.5rem;
        border-radius: var(--border-radius);
        border: 1px solid var(--border-color);
        transition: var(--transition);
    }

    .input-section:hover, .output-section:hover {
        border-color: var(--primary-color);
        box-shadow: var(--shadow-md);
    }

    .section-title {
        font-weight: 600;
        color: var(--text-primary);
        margin-bottom: 1rem;
        display: flex;
        align-items: center;
        gap: 0.5rem;
    }

    .section-title::before {
        content: "";
        width: 4px;
        height: 20px;
        background: var(--primary-color);
        border-radius: 2px;
    }

    /* Textbox styling */
    .gr-textbox {
        border-radius: var(--border-radius) !important;
        border: 2px solid var(--border-color) !important;
        transition: var(--transition) !important;
        background: var(--background-primary) !important;
        font-size: 1rem !important;
        line-height: 1.5 !important;
    }

    .gr-textbox:focus {
        border-color: var(--primary-color) !important;
        box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1) !important;
        outline: none !important;
    }

    .gr-textbox textarea {
        resize: vertical !important;
        min-height: 120px !important;
    }

    /* Button styling */
    .translate-button {
        background: linear-gradient(135deg, var(--primary-color), #7c3aed) !important;
        color: white !important;
        border: none !important;
        border-radius: var(--border-radius) !important;
        padding: 1rem 2rem !important;
        font-size: 1.1rem !important;
        font-weight: 600 !important;
        cursor: pointer !important;
        transition: var(--transition) !important;
        box-shadow: var(--shadow-md) !important;
        text-transform: uppercase !important;
        letter-spacing: 0.5px !important;
        position: relative !important;
        overflow: hidden !important;
    }

    .translate-button:hover {
        transform: translateY(-2px) !important;
        box-shadow: var(--shadow-lg) !important;
    }

    .translate-button:active {
        transform: translateY(0) !important;
    }

    .translate-button::before {
        content: "";
        position: absolute;
        top: 0;
        left: -100%;
        width: 100%;
        height: 100%;
        background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.2), transparent);
        transition: left 0.5s;
    }

    .translate-button:hover::before {
        left: 100%;
    }

    /* Loading animation */
    .loading {
        display: inline-block;
        width: 20px;
        height: 20px;
        border: 3px solid rgba(255, 255, 255, 0.3);
        border-radius: 50%;
        border-top-color: white;
        animation: spin 1s ease-in-out infinite;
        margin-right: 0.5rem;
    }

    @keyframes spin {
        to { transform: rotate(360deg); }
    }

    /* Progress bar styling */
    .progress-bar {
        background: var(--primary-color) !important;
        border-radius: 4px !important;
        height: 4px !important;
    }

    /* Model info cards */
    .model-info {
        display: grid;
        grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
        gap: 1rem;
        margin-top: 2rem;
        padding-top: 2rem;
        border-top: 1px solid var(--border-color);
    }

    .model-card {
        background: var(--background-secondary);
        padding: 1rem;
        border-radius: var(--border-radius);
        border: 1px solid var(--border-color);
        transition: var(--transition);
        text-align: center;
    }

    .model-card:hover {
        border-color: var(--primary-color);
        transform: translateY(-2px);
        box-shadow: var(--shadow-md);
    }

    .model-card h3 {
        color: var(--primary-color);
        margin-bottom: 0.5rem;
        font-size: 1.1rem;
    }

    .model-card p {
        color: var(--text-secondary);
        font-size: 0.9rem;
        margin: 0;
    }

    /* Responsive design */
    @media (max-width: 1024px) {
        .gradio-container {
            padding: 1rem;
        }

        .main-container {
            padding: 1.5rem;
        }

        .header h1 {
            font-size: 2rem;
        }
    }

    @media (max-width: 640px) {
        .header {
            padding: 1.5rem;
            margin-bottom: 2rem;
        }

        .header h1 {
            font-size: 1.8rem;
        }

        .main-container {
            padding: 1rem;
        }

        .translate-button {
            width: 100% !important;
            padding: 0.875rem 1.5rem !important;
        }
    }

    /* Accessibility improvements */
    .sr-only {
        position: absolute;
        width: 1px;
        height: 1px;
        padding: 0;
        margin: -1px;
        overflow: hidden;
        clip: rect(0, 0, 0, 0);
        white-space: nowrap;
        border: 0;
    }

    /* Focus styles for accessibility */
    *:focus {
        outline: 2px solid var(--primary-color);
        outline-offset: 2px;
    }

    /* Custom scrollbar */
    ::-webkit-scrollbar {
        width: 8px;
    }

    ::-webkit-scrollbar-track {
        background: var(--background-secondary);
    }

    ::-webkit-scrollbar-thumb {
        background: var(--primary-color);
        border-radius: 4px;
    }

    ::-webkit-scrollbar-thumb:hover {
        background: var(--primary-hover);
    }
    """
) as demo:
    # Header section
    with gr.Column(elem_classes=["header"]):
        gr.HTML("""
            <h1>🌐 English to Vietnamese Machine Translation</h1>
            <p>Advanced AI-powered translation with multiple model options</p>
        """)

    # Main content
    with gr.Column(elem_classes=["main-container"]):
        # Model selection
        with gr.Row(elem_classes=["model-section"]):
            model_choice = gr.Dropdown(
                choices=[
                    ("Rule-Based MT (RBMT)", "rbmt"),
                    ("Statistical MT (SMT)", "smt"),
                    ("MBart50 (Neural)", "mbart50"),
                    ("mT5 (Neural)", "mt5")
                ],
                label="🤖 Select Translation Model",
                value="mbart50",
                elem_classes=["gr-dropdown"],
                info="Choose the translation approach that best fits your needs"
            )

        # Input/Output section
        with gr.Row(elem_classes=["io-section"]):
            with gr.Column(elem_classes=["input-section"]):
                gr.HTML('<div class="section-title">📝 Input Text (English)</div>')
                input_text = gr.Textbox(
                    placeholder="Enter your English text here...\n\nExample: Hello, how are you today?",
                    lines=6,
                    elem_classes=["gr-textbox"],
                    show_label=False,
                    container=False
                )

            with gr.Column(elem_classes=["output-section"]):
                gr.HTML('<div class="section-title">🇻🇳 Translation (Vietnamese)</div>')
                output_text = gr.Textbox(
                    placeholder="Translation will appear here...",
                    lines=6,
                    elem_classes=["gr-textbox"],
                    interactive=False,
                    show_label=False,
                    container=False
                )

        # Translate button
        translate_button = gr.Button(
            "🚀 Translate Text",
            elem_classes=["translate-button"],
            variant="primary",
            size="lg"
        )

        # Model information cards
        gr.HTML("""
            <div class="model-info">
                <div class="model-card">
                    <h3>RBMT</h3>
                    <p>Rule-based approach using linguistic rules and dictionaries</p>
                </div>
                <div class="model-card">
                    <h3>SMT</h3>
                    <p>Statistical model trained on parallel corpora</p>
                </div>
                <div class="model-card">
                    <h3>MBart50</h3>
                    <p>Facebook's multilingual BART model</p>
                </div>
                <div class="model-card">
                    <h3>mT5</h3>
                    <p>Google's multilingual T5 transformer</p>
                </div>
            </div>
        """)

    # Bind the translation function to the button
    translate_button.click(
        fn=translate_text,
        inputs=[model_choice, input_text],
        outputs=output_text,
        show_progress=True
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()
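For reference, once this app is running as a Space it can also be called programmatically through gradio_client. The sketch below is not part of the upload: the Space id is a placeholder and the api_name is an assumption (Gradio derives the endpoint from the bound function name when none is given); client.view_api() lists the actual routes and argument order.

from gradio_client import Client

client = Client("user/space-name")  # placeholder Space id, not the real one
print(client.view_api())            # confirm the endpoint name and argument order

# Arguments follow the .click() inputs: (model_choice, input_text).
result = client.predict(
    "mbart50",                      # dropdown value selecting the model
    "Hello, how are you today?",    # English text to translate
    api_name="/translate_text",     # assumed; derived from the function name
)
print(result)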
config.json ADDED
@@ -0,0 +1,76 @@
{
    "mbart50": {
        "args": {
            "warmup_steps": 500,
            "lr_scheduler_type": "cosine",
            "per_device_train_batch_size": 8,
            "per_device_eval_batch_size": 8,
            "num_train_epochs": 3,
            "weight_decay": 0.05,
            "max_len": 128,
            "id": null,
            "initial_learning_rate": 5e-5,
            "model_name": "facebook/mbart-large-50-many-to-many-mmt",
            "src_lang": "en_XX",
            "tgt_lang": "vi_VN",
            "wandb_project": "mbart50-lora-en-vi",
            "output_dir": "checkpoints"
        },
        "lora_config": {
            "r": 16,
            "lora_alpha": 32,
            "target_modules": [
                "q_proj",
                "v_proj",
                "k_proj",
                "o_proj"
            ],
            "lora_dropout": 0.2
        },
        "paths": {
            "checkpoint_path": "checkpoints/best_mbart50",
            "base_model_name": "facebook/mbart-large-50-many-to-many-mmt"
        }
    },
    "mt5": {
        "args": {
            "warmup_steps": 500,
            "lr_scheduler_type": "cosine",
            "per_device_train_batch_size": 8,
            "per_device_eval_batch_size": 8,
            "num_train_epochs": 3,
            "weight_decay": 0.05,
            "max_len": 128,
            "id": null,
            "initial_learning_rate": 5e-5,
            "prefix": "translate English to Vietnamese: ",
            "model_name": "google/mt5-base",
            "wandb_project": "mt5-lora-en-vi",
            "output_dir": "checkpoints"
        },
        "lora_config": {
            "r": 16,
            "lora_alpha": 32,
            "target_modules": [
                "q",
                "v",
                "k",
                "o"
            ],
            "lora_dropout": 0.2
        },
        "paths": {
            "checkpoint_path": "checkpoints/best_mt5",
            "base_model_name": "google/mt5-base"
        }
    },
    "metric_weights": {
        "bleu": 0.3,
        "rouge1": 0.15,
        "rouge2": 0.15,
        "rougeL": 0.1,
        "meteor": 0.1,
        "bertscore": 0.1,
        "comet": 0.1
    }
}
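For reference, the lora_config blocks above are in the shape expected by peft's LoraConfig, and metric_weights suggests checkpoint selection by a weighted sum of evaluation metrics. A minimal sketch of how these blocks might be consumed follows; it is not part of the upload, and the task_type and the combined_score helper are assumptions.

import json

from peft import LoraConfig, TaskType, get_peft_model
from transformers import MBartForConditionalGeneration

with open("config.json", "r") as f:
    cfg = json.load(f)["mbart50"]

# Wrap the base model with a LoRA adapter built from the "lora_config" block.
lora = LoraConfig(
    r=cfg["lora_config"]["r"],
    lora_alpha=cfg["lora_config"]["lora_alpha"],
    target_modules=cfg["lora_config"]["target_modules"],
    lora_dropout=cfg["lora_config"]["lora_dropout"],
    task_type=TaskType.SEQ_2_SEQ_LM,  # assumption: seq2seq fine-tuning
)
base = MBartForConditionalGeneration.from_pretrained(cfg["args"]["model_name"])
model = get_peft_model(base, lora)
model.print_trainable_parameters()

def combined_score(metrics: dict, weights: dict) -> float:
    """Illustrative helper: weighted sum of metrics using "metric_weights"."""
    return sum(weights[name] * metrics[name] for name in weights)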
grammar.txt ADDED
@@ -0,0 +1,49 @@
S -> NP VP | WhQ | YNQ | IMP | S Conj S | S AdvP

VP -> V | V NP | V NP PP | V PP | V S | VP Conj VP | Modal VP | V To VP | AdvP VP | VP AdvP | VP NP NP | V AdjP | V PP PP

NP -> PRP | N | PRPS N | Det N | Det AdjP N | Det N PP | NP PP | NP RelClause | PropN | Quant N | CD N | Det | NP Conj NP | N PP | N S

PP -> P NP | P S

WhQ -> WH_Word VP | WH_Word AUX NP VP | WH_Word AUX NP

YNQ -> AUX NP VP | BE NP | BE NP Adj | BE NP NP | BE NP PP | DO NP VP | MD NP VP

IMP -> V NP | V | V PP

RelClause -> WDT VP | WP VP | NP VP | WDT NP VP | WP NP VP | PP WDT VP

Det -> DT | PDT | WDT | PRPS | CD | DT DT

Adj -> JJ | JJR | JJS

Adv -> RB | RBR | RBS | WRB

Conj -> CC | IN

Modal -> MD

To -> TO

PropN -> NNP | NNPS

Quant -> CD | DT

N -> NN | NNS | NNP | NNPS | CD

V -> VB | VBD | VBG | VBN | VBP | VBZ

P -> IN

WH_Word -> WRB | WP | WDT

DO -> VBP | VBZ | VBD

BE -> VBZ | VBP | VBD | VB | VBN | VBG

AUX -> MD | DO | BE | VBP | VBZ | VBD

AdvP -> Adv | Adv Adv | AdvP Conj AdvP | PP | AdvP PP

AdjP -> Adj | Adv Adj | AdjP Conj AdjP | AdjP PP
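grammar.txt defines a context-free grammar over Penn Treebank POS tags (with PRP$ written as PRPS), presumably consumed by the rule-based translator's parser. Below is a minimal sketch, not part of the upload, of how it could be loaded with NLTK; since the grammar has no quoted terminals, the sketch appends terminal rules so that each tag can match itself in a POS-tag sequence.

import nltk

# Penn Treebank tags referenced by grammar.txt (PRP$ spelled as PRPS).
PTB_TAGS = [
    "PRP", "PRPS", "NN", "NNS", "NNP", "NNPS", "DT", "PDT", "WDT", "CD",
    "JJ", "JJR", "JJS", "RB", "RBR", "RBS", "WRB", "CC", "IN", "MD", "TO",
    "VB", "VBD", "VBG", "VBN", "VBP", "VBZ", "WP",
]

def load_grammar(path: str = "grammar.txt") -> nltk.CFG:
    rules = open(path, encoding="utf-8").read()
    # Append quoted terminal productions (e.g. NN -> 'NN') so tags work as tokens.
    terminals = "\n".join(f"{tag} -> '{tag}'" for tag in PTB_TAGS)
    return nltk.CFG.fromstring(rules + "\n" + terminals)

if __name__ == "__main__":
    parser = nltk.ChartParser(load_grammar())
    # POS tags for a sentence like "The cat sleeps" (the output of a tagger).
    for tree in parser.parse(["DT", "NN", "VBZ"]):
        tree.pretty_print()
        break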
infer.py ADDED
@@ -0,0 +1,184 @@
import os
import sys
import json
import argparse
from typing import Tuple, Union, Dict, Any
from pathlib import Path

import torch
from transformers import (
    MBart50Tokenizer,
    MBartForConditionalGeneration,
    MT5ForConditionalGeneration,
    MT5TokenizerFast,
)
from peft import PeftModel, PeftConfig

# Add parent directory to sys.path
sys.path.append(str(Path(__file__).resolve().parent.parent))
from models.rule_based_mt import TransferBasedMT
from models.statistical_mt import SMTExtended, LanguageModel

# Device configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load configuration once
with open("config.json", "r") as json_file:
    CONFIG = json.load(json_file)


def parse_arguments() -> argparse.Namespace:
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description="English-Vietnamese Machine Translation Inference")
    parser.add_argument(
        "--model_type",
        type=str,
        choices=["rbmt", "smt", "mbart50", "mt5"],
        required=True,
        help="Type of model to use for translation",
    )
    parser.add_argument("--text", type=str, required=True, help="Text to translate")
    return parser.parse_args()


class ModelLoader:
    """Handles loading of translation models."""

    @staticmethod
    def load_smt() -> SMTExtended:
        """Load the Statistical Machine Translation model."""
        try:
            smt = SMTExtended()
            model_dir = "checkpoints"
            if os.path.exists(model_dir) and os.path.isfile(os.path.join(model_dir, "phrase_table.pkl")):
                print("Loading existing model...")
                smt.load_model()
            else:
                print("Training new SMT model...")
                stats = smt.train()
                print(f"Training complete: {stats}")
            print("SMT model loaded successfully!")
            return smt
        except Exception as e:
            raise RuntimeError(f"Failed to load SMT model: {str(e)}")

    @staticmethod
    def load_mbart50() -> Tuple[MBartForConditionalGeneration, MBart50Tokenizer]:
        """Load MBart50 model and tokenizer."""
        try:
            model_config = CONFIG["mbart50"]["paths"]
            model = MBartForConditionalGeneration.from_pretrained(model_config["base_model_name"])
            model = PeftModel.from_pretrained(model, model_config["checkpoint_path"])
            tokenizer = MBart50Tokenizer.from_pretrained(model_config["checkpoint_path"])
            model.eval()
            print("MBart50 loaded successfully!")
            return model.to(DEVICE), tokenizer
        except Exception as e:
            raise RuntimeError(f"Failed to load MBart50 model: {str(e)}")

    @staticmethod
    def load_mt5() -> Tuple[MT5ForConditionalGeneration, MT5TokenizerFast]:
        """Load MT5 model and tokenizer."""
        try:
            model_config = CONFIG["mt5"]["paths"]
            model = MT5ForConditionalGeneration.from_pretrained(model_config["base_model_name"])
            model = PeftModel.from_pretrained(model, model_config["checkpoint_path"])
            tokenizer = MT5TokenizerFast.from_pretrained(model_config["checkpoint_path"])
            model.eval()
            print("MT5 loaded successfully!")
            return model.to(DEVICE), tokenizer
        except Exception as e:
            raise RuntimeError(f"Failed to load MT5 model: {str(e)}")


class Translator:
    """Handles translation using different models."""

    @staticmethod
    def translate_rbmt(text: str) -> str:
        """Translate using Rule-Based Machine Translation."""
        try:
            return TransferBasedMT().translate(text)
        except Exception as e:
            raise RuntimeError(f"RBMT translation failed: {str(e)}")

    @staticmethod
    def translate_smt(text: str, smt) -> str:
        """Translate using Statistical Machine Translation."""
        try:
            return smt.translate_sentence(text)
        except Exception as e:
            raise RuntimeError(f"SMT translation failed: {str(e)}")

    @staticmethod
    def translate_mbart50(
        text: str, model: MBartForConditionalGeneration, tokenizer: MBart50Tokenizer
    ) -> str:
        """Translate using the MBart50 model."""
        try:
            model_config = CONFIG["mbart50"]["args"]
            tokenizer.src_lang = model_config["src_lang"]
            inputs = tokenizer([text], return_tensors="pt", padding=True)
            inputs = {key: value.to(DEVICE) for key, value in inputs.items()}

            with torch.no_grad():  # Disable gradient computation for inference
                translated_tokens = model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    forced_bos_token_id=tokenizer.lang_code_to_id[model_config["tgt_lang"]],
                    max_length=128,
                    num_beams=5,
                )
            return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
        except Exception as e:
            raise RuntimeError(f"MBart50 translation failed: {str(e)}")

    @staticmethod
    def translate_mt5(
        text: str, model: MT5ForConditionalGeneration, tokenizer: MT5TokenizerFast
    ) -> str:
        """Translate using the MT5 model."""
        try:
            prefix = CONFIG["mt5"]["args"]["prefix"]
            inputs = tokenizer([prefix + text], return_tensors="pt", padding=True)
            inputs = {key: value.to(DEVICE) for key, value in inputs.items()}

            with torch.no_grad():  # Disable gradient computation for inference
                translated_tokens = model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_length=128,
                    num_beams=5,
                )
            return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
        except Exception as e:
            raise RuntimeError(f"MT5 translation failed: {str(e)}")


def main():
    """Main function to run translation."""
    args = parse_arguments()

    try:
        if args.model_type == "rbmt":
            translation = Translator.translate_rbmt(args.text)
        elif args.model_type == "smt":
            smt = ModelLoader.load_smt()
            translation = Translator.translate_smt(args.text, smt)
        elif args.model_type == "mbart50":
            model, tokenizer = ModelLoader.load_mbart50()
            translation = Translator.translate_mbart50(args.text, model, tokenizer)
        else:  # mt5
            model, tokenizer = ModelLoader.load_mt5()
            translation = Translator.translate_mt5(args.text, model, tokenizer)

        print(f"Translation: {translation}")
    except Exception as e:
        print(f"Error: {str(e)}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()
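Usage note: infer.py doubles as a command-line entry point, e.g. python infer.py --model_type mbart50 --text "Hello, how are you today?", where --model_type is one of rbmt, smt, mbart50, or mt5. It expects config.json, the checkpoints/ directory, and the models/ package to be reachable from the working directory; app.py imports ModelLoader, DEVICE, and Translator from this module to serve the same models behind the Gradio UI.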