DyNin committed on
Commit 8e28984 · verified · 1 Parent(s): 9a353a0

Upload 17 files

Files changed (17)
  1. .gitignore +475 -0
  2. LICENSE +21 -0
  3. LSP-generic.yml +70 -0
  4. LSP-linux.yml +77 -0
  5. LSP_train.py +386 -0
  6. MANIFEST.in +1 -0
  7. README.md +591 -29
  8. SECURITY.md +35 -0
  9. config.json +30 -0
  10. data_config.py +17 -0
  11. data_loader.py +291 -0
  12. demo.py +128 -0
  13. demo_utils.py +91 -0
  14. env.py +8 -0
  15. gradiodemo.py +39 -0
  16. prepro.py +221 -0
  17. requirements.txt +10 -0
.gitignore ADDED
@@ -0,0 +1,475 @@
1
+ ## Ignore Visual Studio temporary files, build results, and
2
+ ## files generated by popular Visual Studio add-ons.
3
+ ##
4
+ ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
5
+
6
+ # User-specific files
7
+ *.suo
8
+ *.user
9
+ *.userosscache
10
+ *.sln.docstates
11
+
12
+ # User-specific files (MonoDevelop/Xamarin Studio)
13
+ *.userprefs
14
+
15
+ # Build results
16
+ [Dd]ebug/
17
+ [Dd]ebugPublic/
18
+ [Rr]elease/
19
+ [Rr]eleases/
20
+ x64/
21
+ x86/
22
+ bld/
23
+ [Bb]in/
24
+ [Oo]bj/
25
+ [Ll]og/
26
+
27
+ # Visual Studio 2015/2017 cache/options directory
28
+ .vs/
29
+ # Uncomment if you have tasks that create the project's static files in wwwroot
30
+ #wwwroot/
31
+
32
+ # Visual Studio 2017 auto generated files
33
+ Generated\ Files/
34
+
35
+ # MSTest test Results
36
+ [Tt]est[Rr]esult*/
37
+ [Bb]uild[Ll]og.*
38
+
39
+ # NUNIT
40
+ *.VisualState.xml
41
+ TestResult.xml
42
+
43
+ # Build Results of an ATL Project
44
+ [Dd]ebugPS/
45
+ [Rr]eleasePS/
46
+ dlldata.c
47
+
48
+ # Benchmark Results
49
+ BenchmarkDotNet.Artifacts/
50
+
51
+ # .NET Core
52
+ project.lock.json
53
+ project.fragment.lock.json
54
+ artifacts/
55
+ **/Properties/launchSettings.json
56
+
57
+ # StyleCop
58
+ StyleCopReport.xml
59
+
60
+ # Files built by Visual Studio
61
+ *_i.c
62
+ *_p.c
63
+ *_i.h
64
+ *.ilk
65
+ *.meta
66
+ *.obj
67
+ *.iobj
68
+ *.pch
69
+ *.pdb
70
+ *.ipdb
71
+ *.pgc
72
+ *.pgd
73
+ *.rsp
74
+ *.sbr
75
+ *.tlb
76
+ *.tli
77
+ *.tlh
78
+ *.tmp
79
+ *.tmp_proj
80
+ *.log
81
+ *.vspscc
82
+ *.vssscc
83
+ .builds
84
+ *.pidb
85
+ *.svclog
86
+ *.scc
87
+
88
+ # Chutzpah Test files
89
+ _Chutzpah*
90
+
91
+ # Visual C++ cache files
92
+ ipch/
93
+ *.aps
94
+ *.ncb
95
+ *.opendb
96
+ *.opensdf
97
+ *.sdf
98
+ *.cachefile
99
+ *.VC.db
100
+ *.VC.VC.opendb
101
+
102
+ # Visual Studio profiler
103
+ *.psess
104
+ *.vsp
105
+ *.vspx
106
+ *.sap
107
+
108
+ # Visual Studio Trace Files
109
+ *.e2e
110
+
111
+ # TFS 2012 Local Workspace
112
+ $tf/
113
+
114
+ # Guidance Automation Toolkit
115
+ *.gpState
116
+
117
+ # ReSharper is a .NET coding add-in
118
+ _ReSharper*/
119
+ *.[Rr]e[Ss]harper
120
+ *.DotSettings.user
121
+
122
+ # JustCode is a .NET coding add-in
123
+ .JustCode
124
+
125
+ # TeamCity is a build add-in
126
+ _TeamCity*
127
+
128
+ # DotCover is a Code Coverage Tool
129
+ *.dotCover
130
+
131
+ # AxoCover is a Code Coverage Tool
132
+ .axoCover/*
133
+ !.axoCover/settings.json
134
+
135
+ # Visual Studio code coverage results
136
+ *.coverage
137
+ *.coveragexml
138
+
139
+ # NCrunch
140
+ _NCrunch_*
141
+ .*crunch*.local.xml
142
+ nCrunchTemp_*
143
+
144
+ # MightyMoose
145
+ *.mm.*
146
+ AutoTest.Net/
147
+
148
+ # Web workbench (sass)
149
+ .sass-cache/
150
+
151
+ # Installshield output folder
152
+ [Ee]xpress/
153
+
154
+ # DocProject is a documentation generator add-in
155
+ DocProject/buildhelp/
156
+ DocProject/Help/*.HxT
157
+ DocProject/Help/*.HxC
158
+ DocProject/Help/*.hhc
159
+ DocProject/Help/*.hhk
160
+ DocProject/Help/*.hhp
161
+ DocProject/Help/Html2
162
+ DocProject/Help/html
163
+
164
+ # Click-Once directory
165
+ publish/
166
+
167
+ # Publish Web Output
168
+ *.[Pp]ublish.xml
169
+ *.azurePubxml
170
+ # Note: Comment the next line if you want to checkin your web deploy settings,
171
+ # but database connection strings (with potential passwords) will be unencrypted
172
+ *.pubxml
173
+ *.publishproj
174
+
175
+ # Microsoft Azure Web App publish settings. Comment the next line if you want to
176
+ # checkin your Azure Web App publish settings, but sensitive information contained
177
+ # in these scripts will be unencrypted
178
+ PublishScripts/
179
+
180
+ # NuGet Packages
181
+ *.nupkg
182
+ # The packages folder can be ignored because of Package Restore
183
+ **/[Pp]ackages/*
184
+ # except build/, which is used as an MSBuild target.
185
+ !**/[Pp]ackages/build/
186
+ # Uncomment if necessary however generally it will be regenerated when needed
187
+ #!**/[Pp]ackages/repositories.config
188
+ # NuGet v3's project.json files produces more ignorable files
189
+ *.nuget.props
190
+ *.nuget.targets
191
+
192
+ # Microsoft Azure Build Output
193
+ csx/
194
+ *.build.csdef
195
+
196
+ # Microsoft Azure Emulator
197
+ ecf/
198
+ rcf/
199
+
200
+ # Windows Store app package directories and files
201
+ AppPackages/
202
+ BundleArtifacts/
203
+ Package.StoreAssociation.xml
204
+ _pkginfo.txt
205
+ *.appx
206
+
207
+ # Visual Studio cache files
208
+ # files ending in .cache can be ignored
209
+ *.[Cc]ache
210
+ # but keep track of directories ending in .cache
211
+ !*.[Cc]ache/
212
+
213
+ # Others
214
+ ClientBin/
215
+ ~$*
216
+ *~
217
+ *.dbmdl
218
+ *.dbproj.schemaview
219
+ *.jfm
220
+ *.pfx
221
+ *.publishsettings
222
+ orleans.codegen.cs
223
+
224
+ # Including strong name files can present a security risk
225
+ # (https://github.com/github/gitignore/pull/2483#issue-259490424)
226
+ #*.snk
227
+
228
+ # Since there are multiple workflows, uncomment next line to ignore bower_components
229
+ # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
230
+ #bower_components/
231
+
232
+ # RIA/Silverlight projects
233
+ Generated_Code/
234
+
235
+ # Backup & report files from converting an old project file
236
+ # to a newer Visual Studio version. Backup files are not needed,
237
+ # because we have git ;-)
238
+ _UpgradeReport_Files/
239
+ Backup*/
240
+ UpgradeLog*.XML
241
+ UpgradeLog*.htm
242
+ ServiceFabricBackup/
243
+ *.rptproj.bak
244
+
245
+ # SQL Server files
246
+ *.mdf
247
+ *.ldf
248
+ *.ndf
249
+
250
+ # Business Intelligence projects
251
+ *.rdl.data
252
+ *.bim.layout
253
+ *.bim_*.settings
254
+ *.rptproj.rsuser
255
+
256
+ # Microsoft Fakes
257
+ FakesAssemblies/
258
+
259
+ # GhostDoc plugin setting file
260
+ *.GhostDoc.xml
261
+
262
+ # Node.js Tools for Visual Studio
263
+ .ntvs_analysis.dat
264
+ node_modules/
265
+
266
+ # Visual Studio 6 build log
267
+ *.plg
268
+
269
+ # Visual Studio 6 workspace options file
270
+ *.opt
271
+
272
+ # Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
273
+ *.vbw
274
+
275
+ # Visual Studio LightSwitch build output
276
+ **/*.HTMLClient/GeneratedArtifacts
277
+ **/*.DesktopClient/GeneratedArtifacts
278
+ **/*.DesktopClient/ModelManifest.xml
279
+ **/*.Server/GeneratedArtifacts
280
+ **/*.Server/ModelManifest.xml
281
+ _Pvt_Extensions
282
+
283
+ # Paket dependency manager
284
+ .paket/paket.exe
285
+ paket-files/
286
+
287
+ # FAKE - F# Make
288
+ .fake/
289
+
290
+ # JetBrains Rider
291
+ .idea/
292
+ *.sln.iml
293
+
294
+ # CodeRush
295
+ .cr/
296
+
297
+ # Python Tools for Visual Studio (PTVS)
298
+ __pycache__/
299
+ *.pyc
300
+
301
+ # Cake - Uncomment if you are using it
302
+ # tools/**
303
+ # !tools/packages.config
304
+
305
+ # Tabs Studio
306
+ *.tss
307
+
308
+ # Telerik's JustMock configuration file
309
+ *.jmconfig
310
+
311
+ # BizTalk build output
312
+ *.btp.cs
313
+ *.btm.cs
314
+ *.odx.cs
315
+ *.xsd.cs
316
+
317
+ # OpenCover UI analysis results
318
+ OpenCover/
319
+
320
+ # Azure Stream Analytics local run output
321
+ ASALocalRun/
322
+
323
+ # MSBuild Binary and Structured Log
324
+ *.binlog
325
+
326
+ # NVidia Nsight GPU debugger configuration file
327
+ *.nvuser
328
+
329
+ # MFractors (Xamarin productivity tool) working folder
330
+ .mfractor/
331
+
332
+
333
+ # preprocessed data caused by testing
334
+ tests/corpus/*.db
335
+
336
+
337
+ # Initially taken from Github's Python gitignore file
338
+
339
+ # Byte-compiled / optimized / DLL files
340
+ __pycache__/
341
+ .idea/
342
+ *.py[cod]
343
+ *$py.class
344
+
345
+ # C extensions
346
+ *.so
347
+
348
+ # Distribution / packaging
349
+ .DS_Store
350
+ .Python
351
+ build/
352
+ develop-eggs/
353
+ dist/
354
+ downloads/
355
+ eggs/
356
+ .eggs/
357
+ lib/
358
+ lib64/
359
+ parts/
360
+ sdist/
361
+ var/
362
+ wheels/
363
+ *.egg-info/
364
+ .installed.cfg
365
+ *.egg
366
+ MANIFEST
367
+
368
+ # PyInstaller
369
+ # Usually these files are written by a python script from a template
370
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
371
+ *.manifest
372
+ *.spec
373
+
374
+ # Installer logs
375
+ pip-log.txt
376
+ pip-delete-this-directory.txt
377
+
378
+ # Unit test / coverage reports
379
+ htmlcov/
380
+ .tox/
381
+ .nox/
382
+ .coverage
383
+ .coverage.*
384
+ .cache
385
+ nosetests.xml
386
+ coverage.xml
387
+ *.cover
388
+ .hypothesis/
389
+ .pytest_cache/
390
+
391
+ # Translations
392
+ *.mo
393
+ *.pot
394
+
395
+ # Django stuff:
396
+ *.log
397
+ local_settings.py
398
+ db.sqlite3
399
+
400
+ # Flask stuff:
401
+ instance/
402
+ .webassets-cache
403
+
404
+ # Scrapy stuff:
405
+ .scrapy
406
+
407
+ # Sphinx documentation
408
+ docs/_build/
409
+
410
+ # PyBuilder
411
+ target/
412
+
413
+ # Jupyter Notebook
414
+ .ipynb_checkpoints
415
+
416
+ # IPython
417
+ profile_default/
418
+ ipython_config.py
419
+
420
+ # pyenv
421
+ .python-version
422
+
423
+ # celery beat schedule file
424
+ celerybeat-schedule
425
+
426
+ # SageMath parsed files
427
+ *.sage.py
428
+
429
+ # unstaged code
430
+ # run_gpt2.py
431
+
432
+ # Environments
433
+ .env
434
+ .venv
435
+ env/
436
+ venv/
437
+ ENV/
438
+ env.bak/
439
+ venv.bak/
440
+
441
+ # Spyder project settings
442
+ .spyderproject
443
+ .spyproject
444
+
445
+ # Rope project settings
446
+ .ropeproject
447
+
448
+ # mkdocs documentation
449
+ /site
450
+
451
+ # mypy
452
+ .mypy_cache/
453
+ .dmypy.json
454
+ dmypy.json
455
+
456
+ # Pyre type checker
457
+ .pyre/
458
+
459
+ # vscode
460
+ .vscode
461
+
462
+ # TF code
463
+ tensorflow_code
464
+
465
+ # Models
466
+ # models
467
+ models/
468
+ # test.txt
469
+ # dstc/
470
+ # scripts/
471
+ # philly/
472
+ # demo/
473
+ # docker/
474
+ # prepro_v4.py
475
+ # prepro.py
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) Microsoft Corporation.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE
LSP-generic.yml ADDED
@@ -0,0 +1,70 @@
1
+ name: LSP
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - _libgcc_mutex=0.1=main
7
+ - blas=1.0=mkl
8
+ - ca-certificates=2019.5.15=1
9
+ - certifi=2019.6.16=py36_1
10
+ - cffi=1.12.3
11
+ - cudatoolkit=10.0.130=0
12
+ - freetype=2.9.1
13
+ - intel-openmp=2019.4
14
+ - jpeg=9b
15
+ - libpng=1.6.37
16
+ - libtiff=4.0.10
17
+ - mkl=2019.4
18
+ - mkl-service=2.3.0
19
+ - mkl_fft=1.0.14
20
+ - mkl_random=1.0.2
21
+ - ninja=1.9.0
22
+ - numpy=1.16.5
23
+ - numpy-base=1.16.5
24
+ - olefile=0.46=py36_0
25
+ - openssl=1.1.1d
26
+ - pillow=6.1.0
27
+ - pip=19.2.2=py36_0
28
+ - pycparser=2.19=py36_0
29
+ - python=3.6.9
30
+ - pytorch=1.2.0
31
+ - setuptools=41.0.1=py36_0
32
+ - six=1.12.0=py36_0
33
+ - sqlite=3.29.0
34
+ - tk=8.6.8
35
+ - torchvision=0.4.0=py36_cu100
36
+ - wheel=0.33.4=py36_0
37
+ - xz=5.2.4
38
+ - zlib=1.2.11
39
+ - zstd=1.3.7
40
+ - nltk=3.4.1
41
+ - pip:
42
+ - backcall==0.1.0
43
+ - boto3==1.9.228
44
+ - botocore==1.12.228
45
+ - chardet==3.0.4
46
+ - decorator==4.4.0
47
+ - docutils==0.15.2
48
+ - idna==2.8
49
+ - ipython==7.8.0
50
+ - ipython-genutils==0.2.0
51
+ - jedi==0.15.1
52
+ - jmespath==0.9.4
53
+ - parso==0.5.1
54
+ - pexpect==4.7.0
55
+ - pickleshare==0.7.5
56
+ - prompt-toolkit==2.0.9
57
+ - ptyprocess==0.6.0
58
+ - pygments==2.4.2
59
+ - python-dateutil==2.8.0
60
+ - pytorch-pretrained-bert==0.6.1
61
+ - regex==2019.8.19
62
+ - requests==2.22.0
63
+ - s3transfer==0.2.1
64
+ - tqdm==4.35.0
65
+ - traitlets==4.3.2
66
+ - urllib3==1.25.3
67
+ - wcwidth==0.1.7
68
+ - flashtext==2.7
69
+ prefix: /home/JJteam/anaconda3/envs/LSP
70
+
LSP-linux.yml ADDED
@@ -0,0 +1,77 @@
1
+ name: LSP
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - _libgcc_mutex=0.1=main
7
+ - blas=1.0=mkl
8
+ - ca-certificates=2019.5.15=1
9
+ - certifi=2019.6.16=py36_1
10
+ - cffi=1.12.3=py36h2e261b9_0
11
+ - cudatoolkit=10.0.130=0
12
+ - freetype=2.9.1=h8a8886c_1
13
+ - intel-openmp=2019.4=243
14
+ - jpeg=9b=h024ee3a_2
15
+ - libedit=3.1.20181209=hc058e9b_0
16
+ - libffi=3.2.1=hd88cf55_4
17
+ - libgcc-ng=9.1.0=hdf63c60_0
18
+ - libgfortran-ng=7.3.0=hdf63c60_0
19
+ - libpng=1.6.37=hbc83047_0
20
+ - libstdcxx-ng=9.1.0=hdf63c60_0
21
+ - libtiff=4.0.10=h2733197_2
22
+ - mkl=2019.4=243
23
+ - mkl-service=2.3.0=py36he904b0f_0
24
+ - mkl_fft=1.0.14=py36ha843d7b_0
25
+ - mkl_random=1.0.2=py36hd81dba3_0
26
+ - ncurses=6.1=he6710b0_1
27
+ - ninja=1.9.0=py36hfd86e86_0
28
+ - numpy=1.16.5=py36h7e9f1db_0
29
+ - numpy-base=1.16.5=py36hde5b4d6_0
30
+ - olefile=0.46=py36_0
31
+ - openssl=1.1.1d=h7b6447c_1
32
+ - pillow=6.1.0=py36h34e0f95_0
33
+ - pip=19.2.2=py36_0
34
+ - pycparser=2.19=py36_0
35
+ - python=3.6.9=h265db76_0
36
+ - pytorch=1.2.0=py3.6_cuda10.0.130_cudnn7.6.2_0
37
+ - readline=7.0=h7b6447c_5
38
+ - setuptools=41.0.1=py36_0
39
+ - six=1.12.0=py36_0
40
+ - sqlite=3.29.0=h7b6447c_0
41
+ - tk=8.6.8=hbc83047_0
42
+ - torchvision=0.4.0=py36_cu100
43
+ - wheel=0.33.4=py36_0
44
+ - xz=5.2.4=h14c3975_4
45
+ - zlib=1.2.11=h7b6447c_3
46
+ - zstd=1.3.7=h0b5b093_0
47
+ - nltk=3.4.1
48
+ - pip:
49
+ - backcall==0.1.0
50
+ - boto3==1.9.228
51
+ - botocore==1.12.228
52
+ - chardet==3.0.4
53
+ - decorator==4.4.0
54
+ - docutils==0.15.2
55
+ - idna==2.8
56
+ - ipython==7.8.0
57
+ - ipython-genutils==0.2.0
58
+ - jedi==0.15.1
59
+ - jmespath==0.9.4
60
+ - parso==0.5.1
61
+ - pexpect==4.7.0
62
+ - pickleshare==0.7.5
63
+ - prompt-toolkit==2.0.9
64
+ - ptyprocess==0.6.0
65
+ - pygments==2.4.2
66
+ - python-dateutil==2.8.0
67
+ - pytorch-pretrained-bert==0.6.1
68
+ - regex==2019.8.19
69
+ - requests==2.22.0
70
+ - s3transfer==0.2.1
71
+ - tqdm==4.35.0
72
+ - traitlets==4.3.2
73
+ - urllib3==1.25.3
74
+ - wcwidth==0.1.7
75
+ - flashtext==2.7
76
+ prefix: /home/JJteam/anaconda3/envs/LSP
77
+
LSP_train.py ADDED
@@ -0,0 +1,386 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT license.
3
+ '''
4
+ * @Desc: train GPT2 from scratch/ fine tuning.
5
+ Modified based on Huggingface GPT-2 implementation
6
+ '''
7
+
8
+ import json
9
+ import os
10
+ import sys
11
+ import argparse
12
+ import logging
13
+ import time
14
+ import tqdm
15
+ import datetime
16
+ import torch
17
+
18
+ import numpy as np
19
+
20
+ from os.path import join
21
+ from torch.distributed import get_rank, get_world_size
22
+
23
+ from lsp_model import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, Adam
24
+ from gpt2_training.train_utils import load_model, boolean_string, set_lr, get_eval_list_same_length
25
+ from gpt2_training.eval_utils import eval_model_loss
26
+
27
+ from data_loader import BucketingDataLoader, DynamicBatchingLoader, DistributedBucketingDataLoader
28
+
29
+
30
+ from gpt2_training.distributed import all_reduce_and_rescale_tensors, all_gather_list
31
+
32
+
33
+ logging.basicConfig(
34
+ format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
35
+ datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
36
+ logger = logging.getLogger(__name__)
37
+
38
+ INF = 100000000
39
+ CACHE_EMPTY_STEP = 10000
40
+ EVAL_STEP = 100000
41
+
42
+ #########################################################################
43
+ # Prepare Parser
44
+ ##########################################################################
45
+
46
+ parser = argparse.ArgumentParser()
47
+ parser.add_argument('--model_name_or_path', type=str,
48
+ help='pretrained model name or path to local checkpoint')
49
+ parser.add_argument("--seed", type=int, default=42)
50
+ parser.add_argument("--max_seq_length", type=int, default=128)
51
+
52
+ parser.add_argument("--skip_eval", action='store_true',
53
+ help='If true, skip evaluation.')
54
+ parser.add_argument("--init_checkpoint", type=str)
55
+ parser.add_argument("--train_input_file", type=str)
56
+ parser.add_argument("--eval_input_file", type=str)
57
+ parser.add_argument("--continue_from", type=int, default=0)
58
+
59
+ parser.add_argument("--train_batch_size", type=int, default=4,
60
+ help="batch size now means per GPU per step")
61
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=2,
62
+ help="to increase effective batch size "
63
+ "and reduce synchronization")
64
+ parser.add_argument("--eval_batch_size", type=int, default=4)
65
+ parser.add_argument("--learning_rate", type=float, default=1e-5)
66
+ parser.add_argument("--num_optim_steps", type=int, default=1000000,
67
+ help="new API specifies num update steps")
68
+ parser.add_argument("--valid_step", type=int, default=10000,
69
+ help="how many optim steps between validations")
70
+ parser.add_argument("--warmup_proportion", type=float, default=0.1)
71
+ parser.add_argument("--warmup_steps", type=int, default=16000)
72
+
73
+ parser.add_argument("--normalize_data", type=boolean_string, default=True)
74
+ parser.add_argument("--fp16", type=boolean_string, default=True)
75
+ parser.add_argument("--lr_schedule", type=str,
76
+ choices=['noam', 'noamwd', 'BERT', 'None'], default='noam')
77
+ parser.add_argument("--loss_scale", type=float, default=0)
78
+ parser.add_argument("--no_token_id", type=boolean_string, default=True)
79
+
80
+ parser.add_argument("--output_dir", type=str)
81
+ parser.add_argument("--log_dir", type=str)
82
+ parser.add_argument('--pbar', type=boolean_string, default=True, help='turn on progress bar')
83
+
84
+ # distributed
85
+ parser.add_argument('--local_rank', type=int, default=-1,
86
+ help='for torch.distributed')
87
+ parser.add_argument('--config', help='JSON config file')
88
+
89
+
90
+ # do normal parsing
91
+ args = parser.parse_args()
92
+
93
+ if args.config is not None:
94
+ # override argparse defaults by config JSON
95
+ opts = json.load(open(args.config))
96
+ for k, v in opts.items():
97
+ if isinstance(v, str):
98
+ # PHILLY ENV special cases
99
+ if 'PHILLY_JOB_DIRECTORY' in v:
100
+ v = v.replace('PHILLY_JOB_DIRECTORY',
101
+ os.environ['PHILLY_JOB_DIRECTORY'])
102
+ elif 'PHILLY_LOG_DIRECTORY' in v:
103
+ v = v.replace('PHILLY_LOG_DIRECTORY',
104
+ os.environ['PHILLY_LOG_DIRECTORY'])
105
+ setattr(args, k, v)
106
+
107
+ # command line should override config JSON
108
+ argv = sys.argv[1:]
109
+ overrides, _ = parser.parse_known_args(argv)
110
+ for k, v in vars(overrides).items():
111
+ if f'--{k}' in argv:
112
+ setattr(args, k, v)
113
+ setattr(args, 'local_rank', overrides.local_rank)
114
+
115
+
116
+ assert args.train_batch_size % args.gradient_accumulation_steps == 0, \
117
+ 'batch size % gradient accumulation steps != 0!'
118
+ args.train_batch_size = (args.train_batch_size
119
+ // args.gradient_accumulation_steps)
120
+ logger.info('train batch size = {}, '
121
+ 'new train batch size (after gradient accumulation) = {}'.format(
122
+ args.train_batch_size*args.gradient_accumulation_steps,
123
+ args.train_batch_size))
124
+
125
+
126
+ if args.local_rank == -1:
127
+ logger.info('CUDA available? {}'.format(str(torch.cuda.is_available())))
128
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
129
+ n_gpu = torch.cuda.device_count()
130
+ args.device, args.n_gpu = device, n_gpu
131
+ else:
132
+ # distributed training
133
+ torch.cuda.set_device(args.local_rank)
134
+ device = torch.device("cuda", args.local_rank)
135
+ # Initializes the distributed backend which will take care of
136
+ # sychronizing nodes/GPUs
137
+ torch.distributed.init_process_group(backend='nccl')
138
+ n_gpu = torch.distributed.get_world_size()
139
+ args.device, args.n_gpu = device, 1
140
+ logger.info("device: {} n_gpu: {}, distributed training: {}, "
141
+ "16-bits training: {}".format(
142
+ device, n_gpu, bool(args.local_rank != -1), args.fp16))
143
+
144
+ np.random.seed(args.seed)
145
+ torch.random.manual_seed(args.seed)
146
+ torch.cuda.manual_seed(args.seed)
147
+ if n_gpu > 0:
148
+ torch.cuda.manual_seed_all(args.seed)
149
+
150
+ timestamp = datetime.datetime.now().strftime('%Y-%m-%d%H%M%S')
151
+ output_dir = join(args.output_dir,
152
+ 'GPT2.{}.{}.{}gpu.{}'.format(args.learning_rate,
153
+ args.train_batch_size, n_gpu,
154
+ timestamp))
155
+ log_dir = args.log_dir if args.log_dir is not None and len(args.log_dir) > 0 else output_dir
156
+ if args.local_rank == -1 or get_rank() == 0:
157
+ os.makedirs(output_dir, exist_ok=True)
158
+
159
+ logger.info('Input Argument Information')
160
+ args_dict = vars(args)
161
+ for a in args_dict:
162
+ logger.info('%-28s %s' % (a, args_dict[a]))
163
+
164
+
165
+ #########################################################################
166
+ # Prepare Data Set
167
+ ##########################################################################
168
+ enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
169
+
170
+ config = GPT2Config.from_json_file(
171
+ join(args.model_name_or_path, 'config.json'))
172
+
173
+ if args.local_rank == -1:
174
+ train_dataloader = BucketingDataLoader(args.train_input_file,
175
+ args.train_batch_size,
176
+ args.max_seq_length)
177
+ else:
178
+ train_dataloader = DistributedBucketingDataLoader(
179
+ get_rank(), get_world_size(),
180
+ args.train_input_file, args.train_batch_size,
181
+ args.max_seq_length)
182
+
183
+ eval_dataloader_loss = DynamicBatchingLoader(
184
+ args.eval_input_file, enc, args.normalize_data,
185
+ args.eval_batch_size, args.max_seq_length)
186
+
187
+ eval_dataloader_gen = get_eval_list_same_length(
188
+ args.eval_input_file, enc, args.eval_batch_size, True)
189
+
190
+
191
+ #########################################################################
192
+ # Prepare Model and Optimizer
193
+ ##########################################################################
194
+ model = load_model(GPT2LMHeadModel(config), args.init_checkpoint,
195
+ args, verbose=True)
196
+ if args.local_rank != -1:
197
+ # when from scratch make sure initial models are the same
198
+ params = [p.data for p in model.parameters()]
199
+ all_reduce_and_rescale_tensors(
200
+ params, float(torch.distributed.get_world_size()))
201
+
202
+ model_parameters = filter(lambda p: p.requires_grad, model.parameters())
203
+ total_params = sum([np.prod(p.size()) for p in model_parameters])
204
+ logger.info('Number of parameter = {}'.format(total_params))
205
+
206
+ param_optimizer = list(model.named_parameters())
207
+ no_decay = ['bias', 'ln'] # no decay for bias and LayerNorm (ln)
208
+ optimizer_grouped_parameters = [
209
+ {'params': [p for n, p in param_optimizer
210
+ if not any(nd in n for nd in no_decay)],
211
+ 'weight_decay': 0.01},
212
+ {'params': [p for n, p in param_optimizer
213
+ if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
214
+ ]
215
+
216
+ if args.fp16:
217
+ logger.info('in fp16, using FusedAdam')
218
+ try:
219
+ from apex.optimizers import FP16_Optimizer
220
+ from apex.optimizers import FusedAdam
221
+ except ImportError:
222
+ raise ImportError(
223
+ "Please install apex from https://www.github.com/nvidia/apex "
224
+ "to use distributed and fp16 training.")
225
+
226
+ optimizer = FusedAdam(optimizer_grouped_parameters,
227
+ lr=args.learning_rate,
228
+ bias_correction=False,
229
+ max_grad_norm=1.0)
230
+ if args.loss_scale == 0:
231
+ optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True,
232
+ verbose=False)
233
+ else:
234
+ optimizer = FP16_Optimizer(optimizer,
235
+ static_loss_scale=args.loss_scale,
236
+ verbose=False)
237
+ else:
238
+ optimizer = Adam(optimizer_grouped_parameters, args.learning_rate,
239
+ max_grad_norm=1.0)
240
+
241
+ #########################################################################
242
+ # Training !
243
+ ##########################################################################
244
+
245
+ if args.local_rank == -1 or get_rank() == 0:
246
+ train_logger = open(join(log_dir, 'train_log.txt'), 'a+', buffering=1)
247
+ eval_logger = open(join(log_dir, 'eval_log.txt'), 'a+', buffering=1)
248
+ print('epoch,global_step,step,mean_loss,mean_ppl,n_token_real,'
249
+ 'n_token_total,epoch_time', file=train_logger)
250
+ print('epoch,global_step,step,eval_loss,eval_ppl', file=eval_logger)
251
+
252
+ global_step = 0
253
+ step = 0
254
+ epoch = 0
255
+
256
+ if args.continue_from:
257
+ global_step = args.continue_from
258
+ step = global_step*2 - 1
259
+
260
+
261
+ if args.local_rank != -1:
262
+ n_gpu = 1
263
+ if args.local_rank == -1 or get_rank() == 0:
264
+ if args.pbar:
265
+ pbar = tqdm.tqdm(total=args.num_optim_steps, desc=f"training")
266
+ else:
267
+ pbar = None
268
+
269
+ while True:
270
+ model.train()
271
+ (tr_loss, tr_ppl, mean_ppl, nb_tr_examples, nb_tr_steps) = 0.0, 0.0, 0.0, 0, 0
272
+ n_token_real, n_token_total = 0, 0
273
+ train_start_time_epoch = time.time()
274
+ for batch in train_dataloader:
275
+ # activate new training mode
276
+ seq_len = batch[0].shape[1]
277
+ batch = tuple(t.to(device) for t in batch)
278
+ input_ids, position_ids, token_ids, label_ids, *_ = batch
279
+ if args.no_token_id:
280
+ token_ids = None
281
+ loss, ppl = model(input_ids, position_ids, token_ids, label_ids)
282
+
283
+ if n_gpu > 1:
284
+ loss = loss.mean()
285
+ ppl = ppl.mean()
286
+ loss = loss / (args.train_batch_size / input_ids.shape[0])
287
+ if args.fp16:
288
+ optimizer.backward(loss)
289
+ else:
290
+ loss.backward()
291
+
292
+ tr_loss += float(loss.item()) * (args.train_batch_size / input_ids.shape[0])
293
+ nb_tr_examples += input_ids.size(0)
294
+ nb_tr_steps += 1
295
+ mean_loss = tr_loss / nb_tr_steps
296
+ if ppl.item() < INF:
297
+ tr_ppl += ppl.item()
298
+ else:
299
+ tr_ppl += mean_ppl
300
+ mean_ppl = tr_ppl / nb_tr_steps
301
+
302
+ n_token_total += input_ids.shape[0] * input_ids.shape[1]
303
+ n_token_real += (input_ids != 0).sum().item()
304
+
305
+ # gradient update
306
+ step += 1
307
+ if step % args.gradient_accumulation_steps == 0:
308
+ set_lr(optimizer, global_step,
309
+ args.lr_schedule, args.learning_rate,
310
+ args.warmup_steps, args.warmup_proportion,
311
+ config.n_embd, args.num_optim_steps)
312
+
313
+ if args.local_rank != -1:
314
+ grads = [p.grad.data for p in model.parameters()
315
+ if p.requires_grad and p.grad is not None]
316
+ all_reduce_and_rescale_tensors(grads, float(1))
317
+
318
+ optimizer.step()
319
+ optimizer.zero_grad()
320
+ global_step += 1
321
+
322
+ # Print log info to file
323
+ if args.local_rank != -1:
324
+ mean_loss = sum(all_gather_list(mean_loss)) / get_world_size()
325
+ mean_ppl = sum(all_gather_list(mean_ppl)) / get_world_size()
326
+ n_token_real_all_proc = sum(all_gather_list(n_token_real))
327
+ n_token_total_all_proc = sum(all_gather_list(n_token_total))
328
+ else:
329
+ n_token_real_all_proc = n_token_real
330
+ n_token_total_all_proc = n_token_total
331
+
332
+ if args.local_rank == -1 or get_rank() == 0:
333
+ epoch_time = time.time() - train_start_time_epoch
334
+ if pbar is not None:
335
+ pbar.set_postfix_str(
336
+ f"tok/s: {n_token_real_all_proc//epoch_time//1000}k "
337
+ f"ppl: {mean_ppl:.2f} epoch: {epoch}")
338
+ pbar.update(1)
339
+ print('{},{},{},{},{},{},{},{}'.format(
340
+ epoch+1, global_step+1, step+1, mean_loss, mean_ppl,
341
+ n_token_real_all_proc, n_token_total_all_proc, epoch_time),
342
+ file=train_logger)
343
+
344
+ if global_step % args.valid_step == 0:
345
+ if args.local_rank == -1 or get_rank() == 0:
346
+ # only rank 0 process evaluate
347
+ torch.save(
348
+ {k: (v.cpu() if v is not None else None) # save to cpu tensors
349
+ for k, v in model.state_dict().items()},
350
+ join(output_dir,
351
+ f'GP2-pretrain-step-{global_step}.pkl'))
352
+
353
+ eval_loss, eval_ppl = eval_model_loss(
354
+ model, enc, eval_dataloader_loss, epoch, args)
355
+ # enable generation step evaluation for now
356
+ # gen_response = eval_model_generation(
357
+ # model, enc, eval_dataloader_gen, epoch, args)
358
+ '''
359
+ # probably use beam search only for test set
360
+ if False:
361
+ gen_response_beam = eval_model_generation(
362
+ model, enc, eval_dataloader_gen, epoch, args,
363
+ use_beam_search=True, beam_width=3)
364
+ '''
365
+ print('{},{},{},{},{}'.format(
366
+ epoch+1, global_step+1, step+1, eval_loss, eval_ppl),
367
+ file=eval_logger)
368
+ logger.info('current learning rate: '
369
+ + str(optimizer.param_groups[0]['lr']))
370
+ model.train()
371
+ if global_step >= args.num_optim_steps:
372
+ break
373
+
374
+ if (step+1) % CACHE_EMPTY_STEP == 0:
375
+ torch.cuda.empty_cache()
376
+
377
+ if global_step >= args.num_optim_steps:
378
+ break
379
+ epoch += 1
380
+
381
+
382
+ if args.local_rank == -1 or get_rank() == 0:
383
+ if pbar is not None:
384
+ pbar.close()
385
+ train_logger.close()
386
+ eval_logger.close()
MANIFEST.in ADDED
@@ -0,0 +1 @@
1
+ include LICENSE
README.md CHANGED
@@ -1,16 +1,425 @@
1
- ---
2
- thumbnail: https://huggingface.co/front/thumbnails/dialogpt.png
3
- tags:
4
- - conversational
5
- license: mit
6
- ---
7
 
8
- ## A State-of-the-Art Large-scale Pretrained Response generation model (DialoGPT)
9
 
10
- DialoGPT is a SOTA large-scale pretrained dialogue response generation model for multiturn conversations.
11
- The [human evaluation results](https://github.com/dreasysnail/Dialogpt_dev#human-evaluation) indicate that the response generated from DialoGPT is comparable to human response quality under a single-turn conversation Turing test.
12
- The model is trained on 147M multi-turn dialogue from Reddit discussion thread.
13
 
14
  * Multi-turn generation examples from an interactive environment:
15
 
16
  |Role | Response |
@@ -22,33 +431,186 @@ The model is trained on 147M multi-turn dialogue from Reddit discussion thread.
22
  |User |This is so difficult ! |
23
  | Bot | You have no idea how hard it is to be a millionaire and happy . There is a reason the rich have a lot of money |
24
 
25
- Please find the information about preprocessing, training and full details of the DialoGPT in the [original DialoGPT repository](https://github.com/microsoft/DialoGPT)
26
 
27
- ArXiv paper: [https://arxiv.org/abs/1911.00536](https://arxiv.org/abs/1911.00536)
28
 
29
- ### How to use
30
 
31
- Now we are ready to try out how the model works as a chatting partner!
32
 
33
- ```python
34
- from transformers import AutoModelForCausalLM, AutoTokenizer
35
- import torch
36
 
 
37
 
38
- tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
39
- model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
40
 
41
- # Let's chat for 5 lines
42
- for step in range(5):
43
- # encode the new user input, add the eos_token and return a tensor in Pytorch
44
- new_user_input_ids = tokenizer.encode(input(">> User:") + tokenizer.eos_token, return_tensors='pt')
45
 
46
- # append the new user input tokens to the chat history
47
- bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids
48
 
49
- # generated a response while limiting the total chat history to 1000 tokens,
50
- chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
51
 
52
- # pretty print last ouput tokens from bot
53
- print("DialoGPT: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))
54
  ```
1
+ # A State-of-the-Art Large-scale Pretrained Response Generation Model (DialoGPT)
 
 
 
 
 
2
 
3
+ ## This project page is no longer maintained as DialoGPT is superseded by [GODEL](https://github.com/microsoft/GODEL), which outperforms DialoGPT according to the results of [this paper](https://arxiv.org/pdf/2206.11309.pdf). Unless you use DialoGPT for reproducibility reasons, we highly recommend you switch to [GODEL](https://github.com/microsoft/GODEL).
4
 
5
+ This repository contains the source code and trained model for a large-scale pretrained dialogue response generation model. The [human evaluation results](#human_eval) indicate that the response generated from DialoGPT is comparable to human response quality under a single-turn conversation Turing test.
 
 
6
 
7
+ <!--See more details on our [project page](https://www.microsoft.com/en-us/research/project/large-scale-pretraining-for-response-generation/)-->
8
+
9
+ The repository is based on [huggingface pytorch-transformer](https://github.com/huggingface/transfer-learning-conv-ai) and [OpenAI GPT-2](https://github.com/openai/gpt-2), containing the data extraction script, model training code, and pretrained small (117M), medium (345M), and large (762M) model checkpoints.
10
+
11
+ The model is trained on 147M multi-turn dialogues from Reddit discussion threads. The largest model can be trained in several hours on a machine with 8 V100 GPUs (though this is not required), with distributed training and the FP16 option.
12
+
13
+ The included scripts can be used to reproduce the results of the DSTC-7 grounded dialogue generation challenge and of a 6k multi-reference dataset created from Reddit data.
14
+
15
+ Project webpage: [https://www.microsoft.com/en-us/research/project/large-scale-pretraining-for-response-generation/](https://www.microsoft.com/en-us/research/project/large-scale-pretraining-for-response-generation/)
16
+
17
+ ArXiv paper: [https://arxiv.org/abs/1911.00536](https://arxiv.org/abs/1911.00536)
18
+
19
+
20
+ ## News ##
21
+
22
+ ***(Update 07/09/2022) Changes on the files.pushshift.io/reddit server caused our data generation pipeline to break. These problems have now been fixed, and the steps explained in the Data Preparation subsection below should work again. Data is generated in about 10 hours with 8 processes (`-j 8`), and 800GB of temporary disk space is needed.***
23
+
24
+ ***(Update 06/23/2021) We have released a retrieval-augmented/grounded version of DialoGPT (RetGen), please check out the [RetGen repo](https://github.com/dreasysnail/RetGen) and [RetGen paper](https://arxiv.org/abs/2105.06597)***
25
+
26
+ ***(Update 05/20/2021) An awesome [video walkthrough](https://www.youtube.com/watch?v=Zo679MYoJns) on YouTube for DialoGPT by [Prakhar Mishra](http://wsl.iiitb.ac.in/prakhar-mishra/)***
27
+
28
+ ***(Update 03/31/2021) A 3rd-party demo by [AK391](https://github.com/AK391) using Gradio: try out the [web demo](https://gradio.app/g/AK391/DialoGPT)!***
29
+
30
+
31
+ ***(Update 09/15/2020) A set of large-scale [dialog ranking models](https://github.com/golsun/DialogRPT) has been released!***
32
+
33
+ DialoGPT generation is improved by integrating with our latest dialog ranking models, [DialogRPT](https://github.com/golsun/DialogRPT)
34
+
35
+ ***(Update 07/08/2020) The 6K multi-ref test set has been released!***
36
+
37
+ To generate the data, please run `demo.py` and set the data option to 'full'; the generated 6k multi-ref test set will be located at
38
+
39
+ `./data/test.refs.txt`
40
+
41
+ ***(Update 03/10/2020) Model cards available in Huggingface Transformers!***
42
+
43
+ Please check out our model cards in the Hugging Face Transformers repository. With several lines of code it should be pretty straightforward to play with DialoGPT interactively.
44
+
45
+ [small model: https://huggingface.co/microsoft/DialoGPT-small](https://huggingface.co/microsoft/DialoGPT-small)
46
+
47
+ [medium model: https://huggingface.co/microsoft/DialoGPT-medium](https://huggingface.co/microsoft/DialoGPT-medium)
48
+
49
+ [large model: https://huggingface.co/microsoft/DialoGPT-large](https://huggingface.co/microsoft/DialoGPT-large)
50
+
51
+ [**(New)** Ranking model: https://huggingface.co/microsoft/DialogRPT-updown](https://huggingface.co/microsoft/DialogRPT-updown?text=I+love+NLP%21+%3C%7Cendoftext%7C%3E+Me+too%21)
52
+
53
+
54
+ ***(Update 01/06/2020) Some third-party decoding script implementations:***
55
+
56
+ - [https://github.com/polakowo/gpt2bot](https://github.com/polakowo/gpt2bot) GPT2Bot implementation based on telegram by polakowo, [ref](https://github.com/microsoft/DialoGPT/issues/3#issuecomment-573904419)
57
+ - [https://colab.research.google.com/drive/1PslHE4Rl4RqSa20s7HEp0ZKITBir6ezE](https://colab.research.google.com/drive/1PslHE4Rl4RqSa20s7HEp0ZKITBir6ezE) A colab interactive notebook by qywu,[ref](https://github.com/microsoft/DialoGPT/issues/3#issuecomment-551410203)
58
+ - [https://github.com/andreamad8/DialoGPT2-Interact](https://github.com/andreamad8/DialoGPT2-Interact) An interactive script featuring multiturn chatbot by andreamad8,[ref](https://github.com/microsoft/DialoGPT/issues/3#issuecomment-551450016)
59
+ - [https://github.com/LHolten/DialoGTP-MMI-decoder](https://github.com/LHolten/DialoGTP-MMI-decoder) An MMI implementation by LHolten,[ref](https://github.com/microsoft/DialoGPT/issues/3#issuecomment-558318401)
60
+ - [https://colab.research.google.com/drive/1-_KjlAV3J1IVDw_9KogjKDCzgFY7Jp7E](https://colab.research.google.com/drive/1-_KjlAV3J1IVDw_9KogjKDCzgFY7Jp7E) A colab interactive notebook by illuminascent@Reddit,[ref](https://www.reddit.com/r/MachineLearning/comments/dt5woy/p_dialogpt_state_of_the_art_conversational_model/?st=k530k3oo&sh=f6cd20fd)
61
+ - [https://colab.research.google.com/drive/15wa925dj7jvdvrz8_z3vU7btqAFQLVlG](https://colab.research.google.com/drive/15wa925dj7jvdvrz8_z3vU7btqAFQLVlG) A great tutorial of how to finetune DialoGPT to build a customized bot built by [Rostyslav Neskorozhenyi](https://www.linkedin.com/in/slanj/). [ref](https://towardsdatascience.com/make-your-own-rick-sanchez-bot-with-transformers-and-dialogpt-fine-tuning-f85e6d1f4e30)
62
+ - [https://gradio.app/g/AK391/DialoGPT](https://gradio.app/g/AK391/DialoGPT) A 3rd party demo by [AK391](https://github.com/AK391) using Gradio [web demo](https://gradio.app/g/AK391/DialoGPT)
63
+
64
+ <!--**This github repository will be updated soon. Please stay tuned.**-->
65
+ <!--## Minimal Computational Configurations-->
66
+ ## Recommended Configuration
67
+
68
+ - Linux Ubuntu 16.04
69
+ - GPU with at least 12G memory
70
+
71
+ DialoGPT was developed entirely on **Ubuntu 16.04**, and -- depending on our availability -- we try to provide support if you experience difficulties running the code on the same configuration. However, we are **unable to provide support for other distributions or operating systems**. Portions of the code may run on other UNIX flavors (macOS, Windows subsystem for Linux, Cygwin, etc.), but it is recommended to use Ubuntu for the main training code.
72
+
73
+ The training code can run on a CPU, but it will be slow. We recommend using GPUs to train and fine-tune all models. There is no minimum number of GPUs. However, when using distributed training across multiple GPUs, the speed-up versus the number of GPUs is roughly sub-linear. To simulate the same effective batch size with fewer GPUs, please use a larger `gradient_accumulation_steps` in model training, as in the sketch below.
74
+
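+ As a rough sketch (flag names are taken from `LSP_train.py` in this repository; other required arguments such as `--train_input_file` are omitted for brevity), the following two launches keep the same effective batch size of 32 sequences per optimizer step while using a similar amount of GPU memory:
+
+ ```bash
+ # 8 GPUs, 4 sequences per GPU per optimizer step (4 x 8 = 32)
+ python -m torch.distributed.launch --nproc_per_node=8 ./LSP_train.py \
+     --train_batch_size 4 --gradient_accumulation_steps 2
+
+ # 2 GPUs: raise the per-GPU batch and the accumulation steps to compensate (16 x 2 = 32)
+ python -m torch.distributed.launch --nproc_per_node=2 ./LSP_train.py \
+     --train_batch_size 16 --gradient_accumulation_steps 8
+ ```
+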
75
+ The 117M and 345M models can be loaded on a single GPU with 12G memory. The 762M model requires a single GPU with more than 16G memory for efficient training. Training speed on a benchmark dataset with 50M training instances and V100 GPUs:
76
+
77
+ | n\_gpu | epoch time (h) | token/sec |
78
+ |----------------------|--------|--------|
79
+ | 1 | 118 | 10847 |
80
+ | 2 | 62 | 20645 |
81
+ | 4 | 34 | 37647 |
82
+ | 8 | 18 | 71356 |
83
+
84
+ Fine-tuning from our pretrained model on a new dataset typically requires 1-2 epochs.
85
+
86
+
87
+ ## Setup & Installation (TL;DR)
88
+
89
+ We created a demo script, `demo.py`, to ease deployment of this system. `demo.py` contains a pipeline of **model downloading**, data extraction, data preprocessing, and model training over a dummy dataset, all within one command line.
90
+
91
+
92
+
93
+ #### Train model with Conda Environment
94
+
95
+ Please use the command lines below to clone the repository, install the requirements, and load the Conda environment (note that the NVIDIA CUDA 10.0 developer toolkit is required):
96
+
97
+
98
+ ```bash
99
+ sudo apt-get install -y make wget gzip bzip2 xz-utils zstd sed
100
+ ```
101
+
102
+ ```bash
103
+ git clone https://github.com/microsoft/DialoGPT.git
104
+ cd DialoGPT
105
+ conda env create -f LSP-linux.yml -n LSP
106
+ conda activate LSP
107
+ ```
108
+
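+ As an optional sanity check (a generic PyTorch one-liner, not part of the repository's scripts), you can verify that the environment sees a CUDA-capable GPU before training:
+
+ ```bash
+ python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
+ ```
+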
109
+ If you run this on a platform other than Linux, please use `LSP-generic.yml` instead of `LSP-linux.yml`, but please note that the generic environment has not been tested on all platforms, so stability cannot be guaranteed.
110
+ To use fp16 training, please install apex using the commands below:
111
+
112
+ ```bash
113
+ conda activate LSP
114
+ git clone https://github.com/NVIDIA/apex
115
+ cd apex
116
+ git reset --hard 3d01e4a0a188cc8df54bc6e44cf5eb40ff6b4cc5
117
+ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
118
+ python3.6 demo.py
119
+ ```
120
+
121
+ #### Train model with Docker environment
122
+ To start, first install the docker and Nvidia-docker from their official repos.
123
+ The image environment for running the code can be loaded as below:
124
+
125
+ *Nvidia-docker v2.\**
126
+
127
+ ```bash
128
+ $ docker run --gpus all --ipc=host --rm -it -v $PWD:/workspace --network=host icaruszyz/large-scale-training:dialogpt bash
129
+ ```
130
+ *Nvidia-docker v1.\**
131
+
132
+ ```bash
133
+ $ nvidia-docker --rm -it -v $PWD:/workspace --network=host icaruszyz/large-scale-training:dialogpt bash
134
+ ```
135
+
136
+ Inside the docker container, run
137
+
138
+ ```bash
139
+ python demo.py
140
+ ```
141
+
142
+
143
+
144
+ ## Pipeline details
145
+
146
+ This section explains all components in `demo.py`.
147
+
148
+ #### Data loading
149
+ Before running `demo.py`, you can set *DATA_FOLDER* (default value `./models`) in `demo.py` to the location where you want to download all the data and pretrained/fine-tuned models. Then simply run
150
+ ```bash
151
+ python demo.py
152
+ ```
153
+ to
154
+
155
+ * automatically download models and data,
156
+ * prepare the raw data into a DB file that the program can use,
157
+ * generate a training script.
158
+
159
+ Note that by default `demo.py` uses dummy data; please specify the Reddit training data using the `--data` option. Three options are available: `dummy`, `small`, and `full`.
160
+
161
+ ```bash
162
+ python demo.py --data small
163
+ python demo.py --data full
164
+ ```
165
+
166
+ The small Reddit data is around 140MB and the full Reddit data is more than 27GB. You can prepare a cup of coffee when processing the full Reddit data because **it takes a long time**!
167
+
168
+ To generate the 6k multi-ref test set data, please run `demo.py` and set the data option to 'full'; the generated file will be located at
169
+
170
+ `./data/test.refs.txt`
171
+
172
+ #### Pretrained model
173
+
174
+ The pretrained and fine-tuned models are available on Azure Blob Storage.
175
+ Please run/see `demo.py` for more details on how to download and use those models, or download them directly using the links in `demo_utils.py`.
176
+
177
+ #### Preparing data
178
+ First, use the `prepare4db.sh` to convert a tsv data file into the correct format that the following script can recognize.
179
+ The trainig data need to be then processed into a database file with below commandline:
180
+
181
+ ```bash
182
+ python prepro.py --corpus $DATA_PATH
183
+ ```
184
+
185
+
186
+
187
+ #### Using the training script
188
+
189
+ The training script can be used in single-GPU or multi-GPU settings (distributed training across multiple GPUs within a single node):
190
+
191
+ ```bash
192
+ python ./LSP_train.py # Single GPU training
193
+ python -m torch.distributed.launch --nproc_per_node=8 ./LSP_train.py # Training on 8 GPUs
194
+ ```
195
+
196
+
197
+ The training script accepts several arguments to tweak the training:
198
+
199
+ Argument | Type | Default value | Description
200
+ ---------|------|---------------|------------
201
+ max\_seq\_length | `int` | `128` | Maximum number of tokens for each training instance.
202
+ train\_input\_file | `str` | `""` | Path of the training dataset in a .db format
203
+ eval\_input\_file | `str` | `""` | Path of the validation set in a tsv format
204
+ continue_from | `int` | `0` | Resuming the training after a specified number of steps
205
+ fp16 | `boolean` | `True` | Whether to use 16-bits floating point for model training.
206
+ train\_batch\_size | `int` | `4` | Batch size for training
207
+ eval\_batch\_size | `int` | `4` | Batch size for validation
208
+ gradient\_accumulation\_steps | `int` | `2` | Accumulate gradients on several steps
209
+ learning\_rate | `float` | `1e-5` | Learning rate
210
+ lr\_schedule | `str` | `noam` | Learning rate schedule can be chosen from [`noam`, `noamwd`, `BERT`, `None`]
211
+ num\_optim\_steps | `int` | `1000000` | Number of training optimization steps
212
+ no_token_id | `boolean` | `True` | If set True, using all-zeros token-type embedding.
213
+
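+ For illustration only, a single-GPU fine-tuning run might look like the command below. The flag names come from the table above and from `LSP_train.py`; the data paths are placeholders for your own `.db`/`.tsv` files, and the checkpoint paths follow the DSTC example further down:
+
+ ```bash
+ python ./LSP_train.py \
+     --model_name_or_path ./models/medium/ \
+     --init_checkpoint ./models/medium/medium_ft.pkl \
+     --train_input_file ./data/train.db \
+     --eval_input_file ./data/valid.tsv \
+     --train_batch_size 4 --gradient_accumulation_steps 2 \
+     --learning_rate 1e-5 --num_optim_steps 10000 \
+     --output_dir ./models/output_model
+ ```
+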
214
+
215
+ During training, two log files are updated: `train_log.txt` and `eval_log.txt` contain the model loss, perplexity, and training speed (tokens/sec) statistics for the training and dev sets.
216
+
217
+ The log files and saved model checkpoints can be found in `./models/output_model`.
218
+
219
+ #### Model decoding
220
+ We note that even with a properly filtered Reddit dataset, our model can sometimes still generate moderately toxic/inappropriate responses. For this reason, we are unable to provide the decoding script at this time (access to the live demo and decoding script is currently by invitation only).
221
+ We are still working on a controlled decoding method to prevent this system from generating toxic responses. Please stay tuned.
222
+
223
+ **See issues [#3](https://github.com/microsoft/DialoGPT/issues/3) and [Reddit discussions](https://www.reddit.com/r/MachineLearning/comments/dt5woy/p_dialogpt_state_of_the_art_conversational_model/) for some discussions on third-party decoding methods.**
224
+
225
+ See below for some third-party decoding methods:
226
+ - [https://github.com/polakowo/gpt2bot](https://github.com/polakowo/gpt2bot) GPT2Bot implementation based on telegram by polakowo, [ref](https://github.com/microsoft/DialoGPT/issues/3#issuecomment-573904419)
227
+ - [https://colab.research.google.com/drive/1PslHE4Rl4RqSa20s7HEp0ZKITBir6ezE](https://colab.research.google.com/drive/1PslHE4Rl4RqSa20s7HEp0ZKITBir6ezE) A colab interactive notebook by qywu,[ref](https://github.com/microsoft/DialoGPT/issues/3#issuecomment-551410203)
228
+ - [https://github.com/andreamad8/DialoGPT2-Interact](https://github.com/andreamad8/DialoGPT2-Interact) An interactive script featuring multiturn chatbot by andreamad8,[ref](https://github.com/microsoft/DialoGPT/issues/3#issuecomment-551450016)
229
+ - [https://github.com/LHolten/DialoGTP-MMI-decoder](https://github.com/LHolten/DialoGTP-MMI-decoder) An MMI implementation by LHolten,[ref](https://github.com/microsoft/DialoGPT/issues/3#issuecomment-558318401)
230
+ - [https://colab.research.google.com/drive/1-_KjlAV3J1IVDw_9KogjKDCzgFY7Jp7E](https://colab.research.google.com/drive/1-_KjlAV3J1IVDw_9KogjKDCzgFY7Jp7E) A colab interactive notebook by illuminascent@Reddit,[ref](https://www.reddit.com/r/MachineLearning/comments/dt5woy/p_dialogpt_state_of_the_art_conversational_model/?st=k530k3oo&sh=f6cd20fd)
231
+ - [https://gradio.app/g/AK391/DialoGPT](https://gradio.app/g/AK391/DialoGPT) A 3rd party demo by [AK391](https://github.com/AK391) using Gradio [web demo](https://gradio.app/g/AK391/DialoGPT)
232
+
233
+ ## Models
234
+
235
+ We release 6 fine-tuned models which can be further fine-tuned on low-resource, user-customized datasets. The total parameters in these models range from 117M to 762M, in accordance with OpenAI GPT-2 model sizes.
236
+
237
+ | Model | Fine-tuned from GPT-2| Trained from scratch
238
+ |----------------------|--------|--------|
239
+ | DialoGPT 762M model| [\[link\]](https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/large_ft.pkl) [\[huggingface model card\]](https://huggingface.co/microsoft/DialoGPT-large) | [\[link\]](https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/large_fs.pkl) |
240
+ | DialoGPT 345M model| [\[link\]](https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/medium_ft.pkl) [\[huggingface model card\]](https://huggingface.co/microsoft/DialoGPT-medium) | [\[link\]](https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/medium_fs.pkl) |
241
+ | DialoGPT 117M model| [\[link\]](https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/small_ft.pkl) [\[huggingface model card\]](https://huggingface.co/microsoft/DialoGPT-small)| [\[link\]](https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/small_fs.pkl) |
242
+ | DialoGPT 345M model (reverse, for MMI)| [link](https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/small_reverse.pkl) | -|
243
+ | [DialogRPT](https://github.com/golsun/DialogRPT) (**new** ranking models) | [link](https://github.com/golsun/DialogRPT) | -|
244
+
245
+
246
+ The model files can be loaded exactly as the GPT-2 model checkpoints from Huggingface's [Transformers](https://github.com/huggingface/transformers). You can find the corresponding configuration files (`merges.txt`, `config.json`, `vocab.json`) in DialoGPT's repo in `./configs/*`.
247
+
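+ For example (the destination folder is illustrative; `demo.py` normally takes care of downloading), the fine-tuned 345M checkpoint from the table above can be fetched directly:
+
+ ```bash
+ wget https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/medium_ft.pkl -P ./models/medium/
+ ```
+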
248
+ The reverse model predicts the source from the target. This model is used for MMI reranking.
249
+
250
+ The [DialogRPT](https://github.com/golsun/DialogRPT) models are our recently proposed ranking models, used to predict human feedback (upvotes, replies) on the responses. These models can be used to improve DialoGPT generation quality (see our [EMNLP paper](https://arxiv.org/abs/2009.06978) for details).
251
+
252
+ ## Retraining full models
253
+
254
+ ### Data Preparation
255
+
256
+ The first step to retrain the full models is to generate the aforementioned 27GB Reddit dataset. This involves downloading full Reddit submission and comments dumps from [https://files.pushshift.io/reddit](https://files.pushshift.io/reddit) and creating intermediate files, which overall require **700GB of local disk space**. Downloading and processing the full data requires about 1-2 days, depending on your (CPU) compute capabilities (e.g., ~24 hours with 8 cores on a recent computer). Assuming you ran the above setup and installation steps (conda activate LSP, etc.), you can create the full dataset by running either:
257
+
258
+ ```
259
+ python demo.py --data full
260
+ ```
261
+ or
262
+ ```
263
+ cd reddit_extractor; SIZE=full make -j 8; cd ..
264
+ ```
265
+
266
+ The former command calls the latter, so the two methods are equivalent. We recommend the former, as the latter is mostly useful if you run into any problem or want to customize any arguments (e.g., the `make` command lets you build only a subset of the data). Note that the downloading phase can be error prone, for example based on your geolocation (firewall, etc.). If the above commands fail to generate `data/train.tsv`, or if that file is not anywhere close to 27GB, it means something went wrong. In that case, you may want to inspect `reddit_extractor/wget-log` and `reddit_extractor/logs/*.log` for any obvious error (e.g., wget unable to download from pushshift.io). If error messages don't make sense to you, feel free to contact us. If so, please be sure to include any error messages gathered from these log files.
267
+
268
+ Training data statistics: the generated training tsv file should be roughly 26.8 GB uncompressed, with 146.8M training instances, 3.87B source tokens, and 2.14B target tokens (including utterance-level 0/1 weights). The resulting train.tsv file should contain 146,846,215 lines.
269
+
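+ A quick sanity check on the generated file (the expected count is taken from the statistics above):
+
+ ```bash
+ wc -l data/train.tsv   # should print 146846215
+ ```
+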
270
+
271
+ ### Training
272
+
273
+ We recommend generating the above data using `demo.py --data full`, as it (1) generates the data, (2) converts it into DB format, and (3) trains a model using `python LSP_train.py`. Please edit `demo.py` directly if you want to customize any of the hyperparameters.
274
+
275
+
276
+ ## Evaluations
277
+
278
+ #### DSTC-7 challenge
279
+
280
+ Our model achieved state-of-the-art results in the [DSTC-7 Challenge response generation task](https://github.com/mgalley/DSTC7-End-to-End-Conversation-Modeling).
281
+
282
+
283
+ | Experiment | NIST2 | NIST4 | BLEU2 | BLEU4 | METEOR | ENT-4 | DIST-1 | DIST-2 | Avg. Len |
284
+ |--------------------|-------|-------|--------|-------|--------|----------|------------|------------|---------|
285
+ | Human response | 2.62 | 2.65 | 12.35% | 3.13% | 8.31% | 10.45 | 16.66% | 67.01% | 18.8 |
286
+ | DSTC-7 Winner | 2.51 | 2.52 | 14.35% | 1.83% | 8.07% | 9.03 | 10.89% | 32.49% | 15.1 |
287
+ | DialoGPT 345M | 2.80 | 2.82 | 14.16% | 2.31% | 8.51% | **10.08** | 9.13% | 39.73% | 16.9 |
288
+ | DialoGPT 345M (BS) | **2.92** | **2.97** | **19.18%** | **6.05%** | **9.29%** | 9.57 | **15.73%** | **51.03%** | 14.2 |
289
+
290
+ where ENT represents the [Entropy score](https://arxiv.org/abs/1809.05972) and DIST the [Distinct score](https://arxiv.org/pdf/1510.03055.pdf). For all metrics except the average length, larger values are better (a short sketch of the DIST computation is given below).
291
+
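+ As a quick reference, DIST-n is simply the number of unique n-grams divided by the total number of n-grams over all generated responses. The snippet below is an illustrative computation only (not the official evaluation script used for the tables here):
+
+ ```python
+ # Illustrative Distinct-n (DIST-n) computation, corpus-level over all hypotheses.
+ from itertools import chain
+
+ def distinct_n(responses, n):
+     """responses: list of tokenized hypotheses (each a list of tokens)."""
+     ngrams = list(chain.from_iterable(
+         zip(*(resp[i:] for i in range(n))) for resp in responses))
+     return len(set(ngrams)) / max(len(ngrams), 1)
+
+ hyps = ["the sun .".split(), "the sun is bigger than the moon .".split()]
+ print(distinct_n(hyps, 1), distinct_n(hyps, 2))  # DIST-1, DIST-2
+ ```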
292
+ <!--| Experiment | NIST1 | NIST2 | NIST3 | NIST4 | BLEU1 | BLEU2 | BLEU3 | BLEU4 | METEOR | ENT-1 | ENT-2 | ENT-3 | ENT-4 | DIST-1 | DIST-2 | Len |
293
+ |----------------------|--------|--------|--------|--------|--------|--------|--------|--------|--------|----------|----------|----------|----------|------------|------------|---------|
294
+ | Human | 2.4237 | 2.6244 | 2.6472 | 2.65 | 0.3408 | 0.1235 | 0.0572 | 0.0313 | 0.0831 | 6.5893 | 9.7423 | 10.4101 | 10.4450 | 0.1666 | 0.6701 | 18.7568 |
295
+ | DSTC-7 Winner | 2.3408 | 2.5102 | 2.522 | 2.523 | 0.4122 | 0.1435 | 0.0501 | 0.0183 | 0.0807 | 5.3832 | 7.6065 | 8.5304 | 9.0298 | 0.1089 | 0.3249 | 15.1327 |
296
+ | DialoGPT | 2.5863 | 2.804 | 2.823 | 2.8246 | 0.3927 | 0.1416 | 0.0555 | 0.0231 | 0.0851 | 5.5791 | 8.5109 | 9.6872 | 10.0765 | 0.0913 | 0.3973 | 16.9484 |
297
+ | DialoGPT(beam search) | **2.5943**| **2.9163** | **2.9624** | **2.9681**| **0.4238** | **0.1918** | **0.1027** | **0.0605** | **0.0929** | **6.0815** | **8.7379** | 9.4037 | 9.5697 | 0.1573 | 0.5103 | 14.1603 |-->
298
+
299
+ Note that automatic evaluation scores exceeding those of human responses do not necessarily imply that our model achieves human parity. Please check out our paper for a more detailed analysis.
300
+
301
+
302
+ To fine-tune the `345M` DialoGPT model on the DSTC-7 challenge data on a server with 8 V100 GPUs, please run the following command line (the DSTC data can be found in the [DSTC-7 repo](https://github.com/mgalley/DSTC7-End-to-End-Conversation-Modeling)):
303
+
304
+ ```bash
305
+ python3 -m torch.distributed.launch --nproc_per_node=8 LSP_train.py --init_checkpoint ./models/medium/medium_ft.pkl --train_input_file ./data/DSTC_train.db --eval_input_file ./data/DSTC_valid.tsv --model_name_or_path ./model/medium/ --learning_rate 1e-4 --train_batch_size 64 --eval_batch_size 64 --no_token_id
306
+ ```
307
+
308
+ The trained model can be found at [DSTC medium model](https://acvrpublicycchen.blob.core.windows.net/dialogpt/DSTC/medium_ft.pkl).
309
+
310
+
311
+ #### Evaluation
312
+
313
+ 1. Please **download** the following 3rd-party packages and save them into the empty folder `3rdparty`:
314
+ * [**mteval-v14c.pl**](https://goo.gl/YUFajQ) to compute [NIST](http://www.mt-archive.info/HLT-2002-Doddington.pdf). You may need to install the following [perl](https://www.perl.org/get.html) modules (e.g., via `cpan install`): XML::Twig, Sort::Naturally, and String::Util.
315
+ * [**meteor-1.5**](http://www.cs.cmu.edu/~alavie/METEOR/download/meteor-1.5.tar.gz) to compute [METEOR](http://www.cs.cmu.edu/~alavie/METEOR/index.html). It requires [Java](https://www.java.com/en/download/help/download_options.xml).
316
+
317
+
318
+ 2. Please follow the [DSTC-7 official repo](https://github.com/mgalley/DSTC7-End-to-End-Conversation-Modeling/tree/master/data_extraction) to extract the data, and put `data-official-test/test.refs.txt` into `./dstc/data/` folder.
319
+
320
+ 3. Run the extraction script below to produce the human response hypothesis file `human.resp.txt`:
321
+
322
+ ```bash
323
+ python extract_human.py
324
+ ```
325
+
326
+ 4. Finally, to reproduce the results of the human hypothesis on the DSTC dataset, please run the following command under the repo folder:
327
+
328
+ ```bash
329
+ python batch_eval.py
330
+ ```
331
+
332
+ The evaluation results will be generated in the folder `./dstc/eval/`.
333
+
334
+
335
+ ## 6K multi-ref dataset result
336
+
337
+ ### Automatic evaluation
338
+
339
+ We test on the 6K multi-ref dataset from Reddit. The results are summarized below:
340
+
341
+ | Experiment | NIST2 | NIST4 | BLEU2 | BLEU4 | METEOR | ENT-4 | DIST-1 | DIST-2 | Avg. Len |
342
+ |--------------------|-------|-------|--------|-------|--------|----------|------------|------------|---------|
343
+ | Human response | 3.41 | 4.25 | 17.90% | 7.48% | 10.64% | 11 | 14.50% | 63.00% | 13.1 |
344
+ | DialoGPT 117M | 2.39 | 2.41 | 10.54% | 1.55% | 7.53% | 10.78 | 8.60% | 39.90% | 12.8 |
345
+ | DialoGPT 345M | 3 | 3.06 | 16.96% | 4.56% | 9.81% | 9.13 | 6.80% | 26.30% | 12.2 |
346
+ | DialoGPT 762M | 2.84 | 2.9 | 18.66% | 5.25% | 9.66% | 9.72 | 7.76% | 29.93% | 11.2 |
347
+ | DialoGPT 345M (BS) | **3.4** | **3.5** | **21.76%** | **7.92%** | 10.74% | 10.48 | **12.38%** | **48.74%** | 11.3 |
348
+ | DialoGPT 345M (w/MMI)| 3.28 | 3.33 | 15.68% | 3.94% | **11.23%** | **11.25** | 9.39% | 45.55% | 17.2 |
349
+
350
+ ### <a name="human_eval"></a>Human evaluation
351
+
352
+ We further conduct human evaluations (6K examples for each method; each example is evaluated by 3 human judges). The results provide strong evidence that our generation quality approaches the quality of real human responses under this non-interactive Turing test:
353
+
354
+
355
+ *Relevance*: Between A and B, which one is more relevant to the source prompt?
356
+
357
+ | System A | A Wins (%) | Ties (%) | B Wins (%) | System B|
358
+ |--------------------|-------|-------|--------|-------|
359
+ |DialoGPT 345M|2671 (45%) | 513 (9%) | 2816 (47%)| Human responses|
360
+ |DialoGPT 345M| 3281 (72%)| 394 (9%) | 882 (19%)| [PersonalityChat](https://docs.microsoft.com/en-us/azure/cognitive-services/project-personality-chat/overview)|
361
+ |DialoGPT 345M w/ MMI| **2871** (48%)| 522 (9%) | 2607 (43%)| Human responses|
362
+
363
+ *Informativeness*: Between A and B, which one is more contentful and informative?
364
+
365
+ | System A | A Wins (%) | Ties (%) | B Wins (%) | System B|
366
+ |--------------------|-------|-------|--------|-------|
367
+ |DialoGPT 345M| 2722 (45%) | 234 (4%) | 3044 (51%)| Human responses|
368
+ |DialoGPT 345M|3490 (77%) | 206 (5%) | 861 (19%)| [PersonalityChat](https://docs.microsoft.com/en-us/azure/cognitive-services/project-personality-chat/overview)|
369
+ |DialoGPT 345M w/ MMI| **3011** (50%)| 234 (4%) | 2755 (46%)| Human responses|
370
+
371
+
372
+ *Human-Like*: Between A and B, which one do you think is more likely to have been generated by a human?
373
+
374
+ | System A | A Wins (%) | Ties (%) | B Wins (%) | System B|
375
+ |--------------------|-------|-------|--------|-------|
376
+ |DialoGPT 345M|2716 (45%) | 263 (4%) | 3021 (50%)| Human responses|
377
+ |DialoGPT 345M|3462 (76%) | 196 (4%) | 899 (20%)| [PersonalityChat](https://docs.microsoft.com/en-us/azure/cognitive-services/project-personality-chat/overview)|
378
+ |DialoGPT 345M w/ MMI| **2978** (50%)| 241 (4%) | 2781 (46%)| Human responses|
379
+
380
+
381
+ Please see full details in our [arxiv paper](https://arxiv.org/abs/1911.00536).
382
+
383
+
384
+
385
+
386
+ <!--Relevance
387
+ System Wins (%) Ties (%) Losses (%)
388
+ 2 vs 1 2671 (0.45) 513 (0.09) 2816 (0.47)
389
+ 2 vs 3 3281 (0.72) 394 (0.09) 882 (0.19)
390
+ 2 vs 4 2379 (0.40) 527 (0.09) 3094 (0.52)
391
+ 2 vs 5 3019 (0.50) 581 (0.10) 2400 (0.40)
392
+ 2 vs 6 2726 (0.45) 576 (0.10) 2698 (0.45)
393
+
394
+ Informativeness
395
+ System Wins (%) Ties (%) Losses (%)
396
+ 2 vs 1 2722 (0.45) 234 (0.04) 3044 (0.51)
397
+ 2 vs 3 3490 (0.77) 206 (0.05) 861 (0.19)
398
+ 2 vs 4 2474 (0.41) 257 (0.04) 3269 (0.54)
399
+ 2 vs 5 3230 (0.54) 362 (0.06) 2408 (0.40)
400
+ 2 vs 6 2856 (0.48) 303 (0.05) 2841 (0.47)
401
+
402
+ Human-Like
403
+ System Wins (%) Ties (%) Losses (%)
404
+ 2 vs 1 2716 (0.45) 263 (0.04) 3021 (0.50)
405
+ 2 vs 3 3462 (0.76) 196 (0.04) 899 (0.20)
406
+ 2 vs 4 2478 (0.41) 289 (0.05) 3233 (0.54)
407
+ 2 vs 5 3233 (0.54) 340 (0.06) 2427 (0.40)
408
+ 2 vs 6 2847 (0.47) 321 (0.05) 2832 (0.47)
409
+ -->
410
+
411
+
412
+ <!--| Experiment | NIST1 | NIST2 | NIST3 | NIST4 | BLEU1 | BLEU2 | BLEU3 | BLEU4 | METEOR | ENT-4 | DIST-1 | DIST-2 |
413
+ |------------------------------|-------|-------|-------|-------|--------|--------|--------|-------|--------|----------|------------|------------|
414
+ | Human response | 2.99 | 3.41 | 3.83 | 4.25 | 39.61% | 17.90% | 10.71% | 7.48% | 10.64% | 11 | 14.50% | 63.00% |
415
+ | DialoGPT 117M | 2.25 | 2.39 | 2.41 | 2.41 | 35.43% | 10.54% | 3.85% | 1.55% | 7.53% | 10.78 | 8.60% | 39.90% |
416
+ | DialoGPT 345M |**2.67** | **3** | **3.05** | **3.06** | 40.97% | 16.96% | 8.31% | 4.56% | **9.81%** | 9.13 | 6.80% | 26.30% |
417
+ | DialoGPT 762M | 2.51 | 2.84 | 2.89 | 2.9 | **44.45%** | **18.66%** | **9.32%** | **5.25%** | 9.66% | **9.72** | **7.76%** | **29.93%** |-->
418
+
419
+ <!--| DialoGPT 762M (beam search)| 2.51 | 2.9 | 2.97 | 2.98 | 44.84% | 21.08% | 12.01% | 7.57% | 10.11% | 10.06 | 11.62% | 44.07% |-->
420
+ <!--## ConvAI challenge -->
421
+ <!--| DialoGPT 345M (beam search) | 2.94 | 3.4 | 3.49 | 3.5 | 45.27% | 21.76% | 12.51% | 7.92% | 10.74% | 10.48 | 12.38% | 48.74% |-->
422
+ ## Generated Examples
423
  * Multi-turn generation examples from an interactive environment:
424
 
425
  |Role | Response |
 
431
  |User |This is so difficult ! |
432
  | Bot | You have no idea how hard it is to be a millionaire and happy . There is a reason the rich have a lot of money |
433
 
434
+ * Multi-turn self-playing dialogue with user prompt
435
+
436
+ |Role | Response |
437
+ |---------|--------|
438
+ |User | what is the meaning of life ?|
439
+ | Bot | The meaning is to be with your family and friends .|
440
+ |Bot | I’m going to guess : It means that your parents and friends have loved you . That’s the meaning , I believe . |
441
+ | Bot | I think the meaning is to make sure you’re alive when you die . So you’ll be with your family , friends , and yourself .|
442
+ |Bot |It’s not , but thanks for trying to help . |
443
 
 
444
 
 
445
 
 
446
 
 
 
 
447
 
448
+ * Commonsense question answering examples
449
 
450
+ |User | Bot |
451
+ |---------|--------|
452
+ |who is the first president of the United States | George Washington |
453
+ |what is the boiling point of water? | I think it's about 212 F. |
454
+ |which one is bigger, sun or moon? | The sun . |
455
+ |which animal has black and white stripes? | A black and white striped cat . |
456
+
457
+ * Handling context that does not appear in training examples (note that training examples are sampled from 2005 to 2017)
458
+
459
+ |User | Bot |
460
+ |---------|--------|
461
+ |The trading war between China and US is still happening . | I'm sure the Chinese government and the U . S . will be fine with that as long as they keep their arms on the table and the money flowing . |
462
+ |who won the world cup in 2018 ? | Germany , I believe . |
463
+ |Nvidia's Titan RTX is really good . | I have this card and i love it . No complaints . |
464
+ |Can Porsche beat Tesla with its new Taycan EV ? | I don't think so . The Tesla is more powerful than the porsche . |
465
+
466
+ * Our bot's response to philosophy questions
467
+
468
+ |User | Bot |
469
+ |---------|--------|
470
+ |What is the meaning of a good life ? | I think it's that we're all connected to our past lives and the meaning of life is to live the way you want to and how you want to . |
471
+ |Which is more important -- being successful or being happy ? | Being successful is just endless . Being happy is the only thing that matters . |
472
+ |Is peace the only way to stop war ? | Not if the world is not willing to work with you . Peace is the ultimate goal , but peace requires more than that . |
473
+ | How much freedom should people have ? | The right amount of freedom is the freedom to do as you please , as long as you don't hurt people or property .|
474
+
475
+
476
+ * Finally
477
+
478
+ |User | Bot |
479
+ |---------|--------|
480
+ |What do you think of the DialoGPT repo ? | It's pretty good , but it's very easy to find a bug . |
481
+
482
+ Please open an issue if you spot any :)
483
+
484
+
485
+
486
+ ## Try our system
487
+ Access to the live demo and decoding script is currently by invitation only. Please stay tuned for the full release.
488
+
489
+
490
+ <!--The live demo and decoding script access is upon approval request. Please apply [here](https://docs.google.com/spreadsheets/d/1epTNfaqva1isVO_o9pbyhVLsnzDn58dGkcLB0OUVcqs/edit?usp=sharing)-->
491
+
492
+ <!--This model should give a Hits@1 over 79, perplexity of 20.5 and F1 of 16.5 using the convai2 evaluation script (see below).
493
+
494
+ These numbers are slightly lower than the number we obtained in the ConvAI2 competition. Here is what you can tweak to reach the same results:
495
+
496
+ - in the ConvAI2 competition we also used tweaked position emebddings so that the history of the dialog always start at with the same embeddings. This is easy to add with pytorch-pretrained-bert and should improve the hits@1 metric.
497
+ - in the ConvAI2 competition we used a beam search decoder. While the results are better in term of f1 metric, our feeling is that the human experience is les compelling with beam search versus the nucleus sampling detector which is provided in the present repository.-->
498
+
499
+ <!--## Using the interaction script
500
+
501
+ The training script saves all the experiments and checkpoints in a sub-folder named with the timestamp of the experiment in the `./runs` folder of the repository base folder.
502
+
503
+ You can then use the interactive script to interact with the model simply by pointing to this folder.
504
+
505
+ Here is an example command line to run the interactive script:
506
+
507
+ ```bash
508
+ python ./interact.py --model_checkpoint ./data/Apr17_13-31-38_thunder/ # run the interactive script with a training checkpoint
509
+ python ./interact.py # run the interactive script with the finetuned model on our S3
510
+ ```
511
+
512
+ The fine-tuned model will gives FINAL Hits@1: 0.715
513
 
514
+ The interactive script accept a few arguments to tweak the decoding algorithm:
 
 
 
515
 
516
+ Argument | Type | Default value | Description
517
+ ---------|------|---------------|------------
518
+ dataset_path | `str` | `""` | Path or url of the dataset. If empty download from S3.
519
+ dataset_cache | `str` | `'./dataset_cache.bin'` | Path or url of the dataset cache
520
+ model | `str` | `"openai-gpt"` | Path, url or short name of the model
521
+ max_history | `int` | `2` | Number of previous utterances to keep in history
522
+ device | `str` | `cuda` if `torch.cuda.is_available()` else `cpu` | Device (cuda or cpu)
523
+ no_sample | action `store_true` | Set to use greedy decoding instead of sampling
524
+ max_length | `int` | `20` | Maximum length of the output utterances
525
+ min_length | `int` | `1` | Minimum length of the output utterances
526
+ seed | `int` | `42` | Seed
527
+ temperature | `int` | `0.7` | Sampling softmax temperature
528
+ top_k | `int` | `0` | Filter top-k tokens before sampling (`<=0`: no filtering)
529
+ top_p | `float` | `0.9` | Nucleus filtering (top-p) before sampling (`<=0.0`: no filtering)
530
 
531
+ ## Running ConvAI2 evaluation scripts
 
532
 
533
+ To run the evaluation scripts of the ConvAI2 challenge, you first need to install `ParlAI` in the repo base folder like this:
534
+
535
+ ```bash
536
+ git clone https://github.com/facebookresearch/ParlAI.git
537
+ cd ParlAI
538
+ python setup.py develop
539
+ ```
540
+
541
+ You can then run the evaluation script from `ParlAI` base folder:
542
+
543
+ ```bash
544
+ cd ParlAI
545
+ python ../convai_evaluation.py --eval_type hits@1 # to download and evaluate our fine-tuned model on hits@1 metric
546
+ python ../convai_evaluation.py --eval_type hits@1 --model_checkpoint ./data/Apr17_13-31-38_thunder/ # to evaluate a training checkpoint on hits@1 metric
547
  ```
548
+
549
+ The evaluation script accept a few arguments to select the evaluation metric and tweak the decoding algorithm:
550
+
551
+ Argument | Type | Default value | Description
552
+ ---------|------|---------------|------------
553
+ eval_type | `str` | `"hits@1"` | Evaluate the model on `hits@1`, `ppl` or `f1` metric on the ConvAI2 validation dataset
554
+ model | `str` | `"openai-gpt"` | Path, url or short name of the model
555
+ max_history | `int` | `2` | Number of previous utterances to keep in history
556
+ device | `str` | `cuda` if `torch.cuda.is_available()` else `cpu` | Device (cuda or cpu)
557
+ no_sample | action `store_true` | Set to use greedy decoding instead of sampling
558
+ max_length | `int` | `20` | Maximum length of the output utterances
559
+ min_length | `int` | `1` | Minimum length of the output utterances
560
+ seed | `int` | `42` | Seed
561
+ temperature | `int` | `0.7` | Sampling softmax temperature
562
+ top_k | `int` | `0` | Filter top-k tokens before sampling (`<=0`: no filtering)
563
+ top_p | `float` | `0.9` | Nucleus filtering (top-p) before sampling (`<=0.0`: no filtering)
564
+
565
+ -->
566
+
567
+ ## Related Projects
568
+
569
+ * RetGen: [https://github.com/dreasysnail/RetGen](https://github.com/dreasysnail/RetGen). Retrieval-augmented/grounded DialoGPT and beyond. RetGen is a joint training framework that simultaneously optimizes a dense passage retriever and a knowledge-grounded text generator in an end-to-end fashion.
570
+
571
+ * Microsoft ICECAPS: [https://github.com/microsoft/icecaps](https://github.com/microsoft/icecaps).
572
+
573
+ As a repository orthogonal to this project,
574
+ Microsoft Icecaps is an open-source toolkit (in TensorFlow) for building neural conversational systems. Icecaps provides an array of tools from recent conversation modeling and general NLP literature within a flexible paradigm that enables complex multi-task learning setups.
575
+
576
+ * Pretrained UniLM: [https://github.com/microsoft/unilm](https://github.com/microsoft/unilm)
577
+ * MT-DNN: [https://github.com/namisan/mt-dnn](https://github.com/namisan/mt-dnn)
578
+ * A Chinese counterpart of DialoGPT by yangjianxin1: [https://github.com/yangjianxin1/GPT2-chitchat](https://github.com/yangjianxin1/GPT2-chitchat). We are glad to see that the MMI strategy we used in DialoGPT has also improved performance for this project!
579
+
580
+ ## Contact
581
+
582
+ Please contact [[email protected]](mailto:[email protected]) if you have any questions/suggestions. However, responses may be sporadic; please expect delays.
583
+
584
+ ## Contributing
585
+
586
+ This project welcomes contributions and suggestions. Most contributions require you to agree to a
587
+ Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
588
+ the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
589
+
590
+ When you submit a pull request, a CLA bot will automatically determine whether you need to provide
591
+ a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
592
+ provided by the bot. You will only need to do this once across all repos using our CLA.
593
+
594
+ This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
595
+ For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
596
+ contact [[email protected]](mailto:[email protected]) with any additional questions or comments.
597
+
598
+ ## Disclaimer
599
+
600
+ This repository aims to facilitate research in large-scale pretraining for conversational data. This toolkit contains only part of the modeling machinery needed to actually produce a model weight file in a running dialog. On its own, this model provides only information about the weights of various text spans; in order to actually use it, a researcher will need to bring conversational data of their own and decode response generation from the pretrained system. Microsoft is not responsible for any generation from third-party utilization of the pretrained system.
601
+
602
+
603
+
604
+ ## Citation
605
+ If you use this code in your research, you can cite our [arxiv paper](https://arxiv.org/abs/1911.00536):
606
+ ```bibtex
607
+ @inproceedings{zhang2019dialogpt,
608
+ title={DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation},
609
+ author={Yizhe Zhang and Siqi Sun and Michel Galley and Yen-Chun Chen and Chris Brockett and Xiang Gao and Jianfeng Gao and Jingjing Liu and Bill Dolan},
610
+ year={2020},
611
+ booktitle={ACL, system demonstration}
612
+ }
613
+ ```
614
+
615
+
616
+
SECURITY.md ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!-- BEGIN MICROSOFT SECURITY.MD V0.0.1 BLOCK -->
2
+
3
+ ## Security
4
+
5
+ Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [many more](https://opensource.microsoft.com/).
6
+
7
+ If you believe you have found a security vulnerability in any Microsoft-owned repository that meets Microsoft's [definition](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)) of a security vulnerability, please report it to us as described below.
8
+
9
+ ## Reporting Security Issues
10
+
11
+ **Please do not report security vulnerabilities through public GitHub issues.** Instead, please report them to the Microsoft Security Response Center at [[email protected]](mailto:[email protected]). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://technet.microsoft.com/en-us/security/dn606155).
12
+
13
+ You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
14
+
15
+ Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
16
+
17
+ * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
18
+ * Full paths of source file(s) related to the manifestation of the issue
19
+ * The location of the affected source code (tag/branch/commit or direct URL)
20
+ * Any special configuration required to reproduce the issue
21
+ * Step-by-step instructions to reproduce the issue
22
+ * Proof-of-concept or exploit code (if possible)
23
+ * Impact of the issue, including how an attacker might exploit the issue
24
+
25
+ This information will help us triage your report more quickly.
26
+
27
+ ## Preferred Languages
28
+
29
+ We prefer all communications to be in English.
30
+
31
+ ## Policy
32
+
33
+ Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
34
+
35
+ <!-- END MICROSOFT SECURITY.MD BLOCK -->
config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_function": "gelu_new",
3
+ "architectures": [
4
+ "GPT2LMHeadModel"
5
+ ],
6
+ "attn_pdrop": 0.1,
7
+ "bos_token_id": 50256,
8
+ "embd_pdrop": 0.1,
9
+ "eos_token_id": 50256,
10
+ "initializer_range": 0.02,
11
+ "layer_norm_epsilon": 1e-05,
12
+ "model_type": "gpt2",
13
+ "n_ctx": 1024,
14
+ "n_embd": 1024,
15
+ "n_head": 16,
16
+ "n_layer": 24,
17
+ "n_positions": 1024,
18
+ "resid_pdrop": 0.1,
19
+ "summary_activation": null,
20
+ "summary_first_dropout": 0.1,
21
+ "summary_proj_to_labels": true,
22
+ "summary_type": "cls_index",
23
+ "summary_use_proj": true,
24
+ "task_specific_params": {
25
+ "conversational": {
26
+ "max_length": 1000
27
+ }
28
+ },
29
+ "vocab_size": 50257
30
+ }
data_config.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT license.
3
+ import os
4
+ from . import proj_env
5
+
6
+ RAW_DATA_DIR = os.path.join(proj_env.ROOT_DIR, "raw_data")
7
+ PROCESSED_DATA_DIR = os.path.join(proj_env.ROOT_DIR, "processed")
8
+ PIPELINE_DATA_DIR = os.path.join(proj_env.ROOT_DIR, "pipeline_data")
9
+
10
+ TEST_KEY_FN = os.path.join(RAW_DATA_DIR, "keys.2k.txt")
11
+
12
+
13
+ MAX_LEN = 128 #512
14
+ MAX_CONTEXT_LEN = 64#250
15
+
16
+ TAG_LIST = ["<p>", "<title>", "<anchor>"] + ["<h%s>" % i for i in range(1, 7)]
17
+
data_loader.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT license.
3
+ import gzip
4
+ import json
5
+ import math
6
+ import random
7
+ import shelve
8
+ import torch
9
+
10
+ import subprocess as sp
11
+
12
+ from math import ceil
13
+ from torch.utils.data import DataLoader, Sampler, Dataset
14
+ from torch.nn.utils.rnn import pad_sequence
15
+
16
+ from env import END_OF_TEXT_TOKEN
17
+ from gpt2_training.train_utils import (InputFeatures, InputFeatures_train,
18
+ RedditExample)
19
+
20
+
21
+ class BucketSampler(Sampler):
22
+ """
23
+ this sampler will sort data by sequence length
24
+ """
25
+ def __init__(self, lens, bucket_size, batch_size,
26
+ droplast=False, shuffle=True):
27
+ self._lens = lens
28
+ self._batch_size = batch_size
29
+ self._bucket_size = bucket_size
30
+ self._droplast = droplast
31
+ self._shuf = shuffle
32
+
33
+ def __iter__(self):
34
+ ids = list(range(len(self._lens)))
35
+ if self._shuf:
36
+ random.shuffle(ids)
37
+ buckets = [sorted(ids[i:i+self._bucket_size],
38
+ key=lambda i: self._lens[i], reverse=True)
39
+ for i in range(0, len(ids), self._bucket_size)]
40
+ batches = [bucket[i:i+self._batch_size]
41
+ for bucket in buckets
42
+ for i in range(0, len(bucket), self._batch_size)]
43
+ if self._droplast:
44
+ batches = [batch for batch in batches
45
+ if len(batch) == self._batch_size]
46
+ if self._shuf:
47
+ random.shuffle(batches)
48
+ return iter(batches)
49
+
50
+ def __len__(self):
51
+ bucket_sizes = ([self._bucket_size]
52
+ * (len(self._lens) // self._bucket_size)
53
+ + [len(self._lens) % self._bucket_size])
54
+ if self._droplast:
55
+ return sum(s//self._batch_size for s in bucket_sizes)
56
+ else:
57
+ return sum(math.ceil(s/self._batch_size) for s in bucket_sizes)
58
+
59
+
60
+ class GPT2FeatureDataset(Dataset):
61
+ """ pytorch dataset for GPT2 training """
62
+ def __init__(self, features, max_len=None):
63
+ self.features = features
64
+ self.max_len = max_len # this max_len do truncate
65
+
66
+ def __getitem__(self, i):
67
+ feat_dict = self.features[i]
68
+ if self.max_len is not None and feat_dict['input_len'] > self.max_len:
69
+ # truncate on the left side (context)
70
+ feat_dict['input_ids'] = feat_dict['input_ids'][-self.max_len:]
71
+ feat_dict['position_ids'] = feat_dict['position_ids'][
72
+ -self.max_len:]
73
+ feat_dict['token_type_ids'] = feat_dict['token_type_ids'][
74
+ -self.max_len:]
75
+ feat_dict['lm_labels'] = feat_dict['lm_labels'][-self.max_len:]
76
+ try:
77
+ for s in ['context_len', 'response_len']:
78
+ if s in feat_dict.keys():
79
+ print("db file missing "+s)
80
+ del feat_dict[s]
81
+ except Exception:
82
+ import pdb
83
+ pdb.set_trace()
84
+
85
+ feat = InputFeatures_train(**feat_dict)
86
+ return feat
87
+
88
+ def __len__(self):
89
+ return len(self.features)
90
+
91
+ @staticmethod
92
+ def collate(features):
93
+ input_ids = pad_sequence([torch.tensor(f.input_ids, dtype=torch.long)
94
+ for f in features],
95
+ batch_first=True, padding_value=0)
96
+ position_ids = pad_sequence([torch.tensor(f.position_ids,
97
+ dtype=torch.long)
98
+ for f in features],
99
+ batch_first=True, padding_value=0)
100
+ token_type_ids = pad_sequence([torch.tensor(f.token_type_ids,
101
+ dtype=torch.long)
102
+ for f in features],
103
+ batch_first=True, padding_value=0)
104
+ labels = pad_sequence([torch.tensor(f.lm_labels, dtype=torch.long)
105
+ for f in features],
106
+ batch_first=True, padding_value=-1)
107
+ return (input_ids, position_ids, token_type_ids, labels)
108
+
109
+
110
+ class BucketingDataLoader(object):
111
+ """ this loads shelve db chunks and then convert to mini-batch loader"""
112
+ def __init__(self, db_name, batch_size, max_seq_len,
113
+ bucket=100, shuffle=True):
114
+ self.db = shelve.open(f'{db_name}/db', 'r')
115
+ self.batch_size = batch_size
116
+ self.max_len = max_seq_len
117
+ self.bucket_size = bucket * batch_size
118
+ self.shuffle = shuffle
119
+
120
+ def _get_keys(self):
121
+ keys = list(self.db.keys())
122
+ return keys
123
+
124
+ def __iter__(self):
125
+ keys = self._get_keys()
126
+ if self.shuffle:
127
+ random.shuffle(keys)
128
+ for key in keys:
129
+ chunk = json.loads(gzip.decompress(self.db[key]).decode('utf-8'))
130
+ # discard long examples
131
+ trunc_chunk = []
132
+ lens = []
133
+ for feat in chunk:
134
+ if feat['input_len'] > self.max_len:
135
+ continue
136
+ trunc_chunk.append(feat)
137
+ lens.append(feat['input_len'])
138
+
139
+ dataset = GPT2FeatureDataset(trunc_chunk, self.max_len)
140
+ sampler = BucketSampler(lens, self.bucket_size, self.batch_size,
141
+ droplast=True, shuffle=self.shuffle)
142
+ loader = DataLoader(dataset, batch_sampler=sampler,
143
+ num_workers=0, # can test multi-worker
144
+ collate_fn=GPT2FeatureDataset.collate)
145
+ yield from loader
146
+
147
+ def __len__(self):
148
+ raise NotImplementedError()
149
+
150
+ def __del__(self):
151
+ self.db.close()
152
+
153
+
154
+ class DistributedBucketingDataLoader(BucketingDataLoader):
155
+ """ distributed version """
156
+ def __init__(self, rank, num_replica, *args, **kwargs):
157
+ super().__init__(*args, **kwargs)
158
+ self.rank = rank
159
+ self.num_replica = num_replica
160
+
161
+ def _get_keys(self):
162
+ keys = list(self.db.keys())[self.rank::self.num_replica]
163
+ return keys
164
+
165
+
166
+ def convert_examples_to_features_dynamic(examples, tokenizer,
167
+ max_seq_length=512):
168
+ """
169
+ do not pad
170
+ """
171
+ def featurize(example):
172
+ conv_id = example.conv_id
173
+ context_id = tokenizer.encode(example.context)
174
+ end_of_text_id = tokenizer.encoder[END_OF_TEXT_TOKEN]
175
+
176
+ # response is provided in example
177
+ response_id = tokenizer.encode(example.response)
178
+
179
+ input_ids_len = len(context_id) + len(response_id) + 2
180
+ if input_ids_len > max_seq_length:
181
+ if len(context_id) > input_ids_len - max_seq_length:
182
+ # cut context from beginning if length of context + response is too long
183
+ # and len of context is long enough to cut
184
+ context_id = context_id[input_ids_len - max_seq_length:]
185
+ else:
186
+ # cut response from end if length of context + response is too long
187
+ # and len of response is long enough to cut
188
+ # if no response is available, discard the data
189
+ if max_seq_length-len(context_id)-2 < 0:
190
+ return None
191
+ response_id = response_id[:max_seq_length-len(context_id)-2]
192
+
193
+ input_ids = context_id + [end_of_text_id] + response_id + [end_of_text_id]
194
+
195
+ # label simplely is next token in sequences. MASK all context_id tokens except for the last one
196
+ lm_labels = [-1] * len(context_id) + response_id + [end_of_text_id] + [-1]
197
+
198
+ position_ids = list(range(len(input_ids)))
199
+
200
+ token_type_id = [0] * len(input_ids)
201
+
202
+ return InputFeatures(conv_id, input_ids, position_ids, token_type_id,
203
+ lm_labels, len(context_id), len(response_id))
204
+
205
+ # discard None feature
206
+ features = [f for f in [featurize(ex) for ex in examples] if f is not None]
207
+ return features
208
+
209
+
210
+ class DynamicBatchingLoader(object):
211
+ """ this loader takes raw text file, used for validate perplexity """
212
+ def __init__(self, corpus_file, tokenizer, normalize_data,
213
+ batch_size, max_seq_length):
214
+ self.corpus = corpus_file
215
+ self.toker = tokenizer
216
+ self.norm = normalize_data
217
+ self.bs = batch_size
218
+ self.max_seq_length = max_seq_length
219
+ self.num_examples = self.get_len(corpus_file)
220
+
221
+ def __iter__(self, epoch=1):
222
+ if epoch > 0:
223
+ for epoch in range(epoch):
224
+ yield from self._iter_epoch()
225
+ else:
226
+ while True:
227
+ yield from self._iter_epoch()
228
+
229
+ def __len__(self):
230
+ return ceil(self.num_examples/self.bs)
231
+
232
+ def _iter_epoch(self):
233
+ try:
234
+ with open(self.corpus, 'r', encoding="utf-8") as corpus:
235
+ i = 0
236
+ while True:
237
+ examples = []
238
+ cur_bs = 0
239
+ while True:
240
+ line = next(corpus).encode('utf-8').decode('utf-8')
241
+ contents = line.split('\t')
242
+ src, tgt_all = contents[0], contents[1:]
243
+ for tgt in tgt_all:
244
+ if self.norm:
245
+ src_line = ' '.join(src.strip().split())
246
+ tgt_line = ' '.join(tgt.strip().split())
247
+ else:
248
+ src_line = src.strip()
249
+ tgt_line = tgt.strip()
250
+ examples.append(
251
+ RedditExample(i, src_line, tgt_line),
252
+ )
253
+ i += 1
254
+ cur_bs += 1
255
+ if cur_bs >= self.bs:
256
+ break
257
+ features = convert_examples_to_features_dynamic(
258
+ examples, self.toker, self.max_seq_length)
259
+ batch = self._batch_feature(features)
260
+ yield batch
261
+ except StopIteration:
262
+ pass
263
+
264
+ def _batch_feature(self, features):
265
+ input_ids = pad_sequence([torch.tensor(f.choices_features['input_ids'],
266
+ dtype=torch.long)
267
+ for f in features],
268
+ batch_first=True, padding_value=0)
269
+ position_ids = pad_sequence(
270
+ [torch.tensor(f.choices_features['position_ids'], dtype=torch.long)
271
+ for f in features],
272
+ batch_first=True, padding_value=0)
273
+ token_type_ids = pad_sequence(
274
+ [torch.tensor(f.choices_features['token_type_ids'],
275
+ dtype=torch.long)
276
+ for f in features],
277
+ batch_first=True, padding_value=0)
278
+ labels = pad_sequence([torch.tensor(f.lm_labels, dtype=torch.long)
279
+ for f in features],
280
+ batch_first=True, padding_value=-1)
281
+ context_len = torch.tensor([f.context_len for f in features],
282
+ dtype=torch.long)
283
+ response_len = torch.tensor([f.response_len for f in features],
284
+ dtype=torch.long)
285
+ return (input_ids, position_ids, token_type_ids, labels,
286
+ context_len, response_len)
287
+
288
+ def get_len(self, corpus):
289
+ n_line = int(sp.check_output(f"wc -l {corpus}".split(),
290
+ universal_newlines=True).split()[0])
291
+ return n_line
demo.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT license.
3
+ #
4
+ # Please assign the DATA_FOLDER before running this script; the data, pre-trained model, and fine-tuned model will be
5
+ # downloaded automatically to DATA_FOLDER
6
+
7
+ import os
8
+ import sys
9
+ import logging
10
+ from functools import partial
11
+
12
+ from demo_utils import download_model_folder
13
+ import argparse
14
+ import subprocess as sp
15
+
16
+
17
+ PROJECT_FOLDER = os.path.dirname(os.path.realpath(__file__))
18
+ PYTHON_EXE = 'python'
19
+ MODEL_FOLDER = os.path.join(PROJECT_FOLDER, 'models')
20
+ DATA_FOLDER = os.path.join(PROJECT_FOLDER, 'data')
21
+
22
+ print(f'PROJECT_FOLDER = {PROJECT_FOLDER}')
23
+
24
+ parser = argparse.ArgumentParser()
25
+ parser.add_argument('--data', type=str, default='dummy',
26
+ help='choose from dummy, small and full')
27
+ dargs = parser.parse_args()
28
+
29
+ assert dargs.data == 'dummy' or dargs.data == 'small' or dargs.data == 'full' , \
30
+ 'The specified data option is not supported!'
31
+
32
+
33
+ logging.basicConfig(
34
+ format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
35
+ datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO
36
+ )
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
+ if os.path.exists(MODEL_FOLDER):
41
+ print(f'Found existing models folder at {MODEL_FOLDER}, skip creating a new one!')
42
+ os.makedirs(MODEL_FOLDER, exist_ok=True)
43
+ else:
44
+ os.makedirs(MODEL_FOLDER)
45
+
46
+ #########################################################################
47
+ # Download Model
48
+ #########################################################################
49
+ logger.info('Downloading models...')
50
+ download_model = partial(download_model_folder, DATA_FOLDER=MODEL_FOLDER)
51
+
52
+ # model size: could be one of 'small' (GPT2 with 117M), 'medium'(345M) or 'large' (1542M)
53
+ # dataset: one of 'multiref' or 'dstc'
54
+ # from_scratch: True : load model trained from scratch or False: load model trained from fine-tuning the GPT-2
55
+ target_folder = download_model(model_size='small', dataset='multiref', from_scratch=False)
56
+ logger.info('Done!\n')
57
+
58
+
59
+ #########################################################################
60
+ # Prepare Data
61
+ #########################################################################
62
+ logger.info('Downloading and Extracting Data...')
63
+ if dargs.data == 'dummy':
64
+ cmd = 'bash prepare4db.sh'
65
+ ret = sp.run(cmd.split(' '), stdout=sp.PIPE, stderr=sp.STDOUT, cwd=DATA_FOLDER)
66
+ elif dargs.data == 'small':
67
+ myCmd = os.popen('cd reddit_extractor; SIZE=small make -j 8; cd ..').read()
68
+ elif dargs.data == 'full':
69
+ myCmd = os.popen('cd reddit_extractor; SIZE=full make -j 8; cd ..').read()
70
+ else:
71
+ raise ValueError('you need to implement your own data type, or use either dummy, small, or full')
72
+
73
+ logger.info('Preparing Data...')
74
+ data_path = os.path.join(DATA_FOLDER, 'train.tsv')
75
+ MAX_LEN = 128
76
+ data_db = f'{data_path[:-4]}.{MAX_LEN}len.db'
77
+ if os.path.isdir(data_db):
78
+ print(f'{data_db} exists, skip prepro.py')
79
+ else:
80
+ cmd = ['prepro.py', '--corpus', data_path, '--max_seq_len', f'{MAX_LEN}']
81
+ cmd = ' '.join(cmd) #% {'CODE_ROOT': CODE_ROOT}
82
+ print(cmd)
83
+ ret = sp.run([PYTHON_EXE] + cmd.split(' '), stdout=sp.PIPE, stderr=sp.STDOUT, cwd=PROJECT_FOLDER)
84
+ if ret.returncode != 0:
85
+ print(f'error occurred, {ret.stdout}')
86
+ sys.exit(ret.returncode)
87
+ logger.info('Done!\n')
88
+
89
+ #########################################################################
90
+ # Train !
91
+ #########################################################################
92
+ logger.info('Generating training CMD!')
93
+ logger.info('If there is any problem, please copy (modify) and run command below')
94
+ logger.info('#########################################################################')
95
+ train_cmd = 'LSP_train.py'
96
+ args = [
97
+ '--model_name_or_path', target_folder,
98
+ '--init_checkpoint', os.path.join(target_folder, 'pytorch_model.bin'),
99
+ '--train_input_file', data_db , # file from last step
100
+ '--eval_input_file', './data/dummy_data.tsv', # dummy test data
101
+ '--output_dir', os.path.join(MODEL_FOLDER, 'output_model'),
102
+ '--seed', '42',
103
+ '--max_seq_length', '128',
104
+ '--train_batch_size', '512',
105
+ '--gradient_accumulation_steps', '8',
106
+ '--eval_batch_size', '64',
107
+ '--learning_rate', '1e-5',
108
+ '--num_optim_steps', '10000',
109
+ '--valid_step', '5000',
110
+ '--warmup_steps', '4000',
111
+ '--normalize_data', 'true',
112
+ '--fp16', 'true',
113
+ '--lr_schedule', 'noam',
114
+ '--loss_scale', '0.0',
115
+ '--no_token_id', 'true',
116
+ '--pbar', 'true'
117
+ ]
118
+
119
+ arg = ' '.join(args)
120
+ train_cmd = train_cmd + ' ' + arg
121
+ print(PYTHON_EXE + ' ' +train_cmd)
122
+ logger.info('#########################################################################')
123
+ with open('./output.log', 'wb') as f:
124
+ process = sp.Popen([PYTHON_EXE] + train_cmd.split(' '), stdout=sp.PIPE, stderr=sp.STDOUT, cwd=PROJECT_FOLDER)
125
+ for line in iter(process.stdout.readline, b''):
126
+ sys.stdout.write(line.decode(sys.stdout.encoding))
127
+ f.write(line)
128
+ logger.info('Done!\n')
demo_utils.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT license.
3
+
4
+ import os
5
+ import logging
6
+
7
+ from pytorch_pretrained_bert.file_utils import http_get
8
+
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ # Note that the model size is roughly half of the GPT model because our model is saved in fp16
14
+ LSP_MODEL_URL = {
15
+ 'multiref': {
16
+ 'large_fs': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/large_fs.pkl',
17
+ 'medium_fs': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/medium_fs.pkl',
18
+ 'medium_ft': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/medium_ft.pkl',
19
+ 'small_fs': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/small_fs.pkl',
20
+ 'small_ft': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/small_ft.pkl'
21
+ },
22
+ 'dstc': {
23
+ 'medium_ft': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/DSTC/medium_ft.pkl'
24
+ }
25
+ }
26
+
27
+ # GPT model could be downloaded from huggingface repo
28
+ GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {
29
+ "small": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
30
+ "medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin",
31
+ "large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin"
32
+ }
33
+
34
+ CONFIG_FILE = {
35
+ 'small': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/117M/config.json',
36
+ 'medium': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/345M/config.json',
37
+ 'large': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/1542M/config.json'
38
+ }
39
+
40
+ VOCAB_FILE = {
41
+ 'small': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/117M/vocab.json',
42
+ 'medium': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/345M/vocab.json',
43
+ 'large': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/1542M/vocab.json'
44
+ }
45
+
46
+ MERGE_FILE = {
47
+ 'small': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/117M/merges.txt',
48
+ 'medium': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/345M/merges.txt',
49
+ 'large': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/1542M/merges.txt'
50
+ }
51
+
52
+
53
+ def download_file(url, folder):
54
+ if not os.path.exists(folder):
55
+ os.makedirs(folder, exist_ok=True)
56
+
57
+ file_name = os.path.basename(url)
58
+ if 'pytorch_model.bin' in file_name:
59
+ file_name = 'pytorch_model.bin'
60
+
61
+ if os.path.isfile(os.path.join(folder, file_name)):
62
+ logger.info(f'{os.path.join(folder, file_name)} exists, return!')
63
+ return
64
+
65
+ with open(os.path.join(folder, file_name), 'wb') as f:
66
+ http_get(url, f)
67
+
68
+
69
+ def download_model_folder(model_size, dataset=None, from_scratch=None, DATA_FOLDER=None):
70
+ assert DATA_FOLDER is not None, 'DATA_FOLDER cannot be None'
71
+ assert model_size in ['small', 'medium', 'large'], 'model size should be one of \'small\', \'medium\' or \'large\''
72
+ target_folder = os.path.join(DATA_FOLDER, model_size)
73
+ download_file(CONFIG_FILE[model_size], target_folder)
74
+ download_file(VOCAB_FILE[model_size], target_folder)
75
+ download_file(MERGE_FILE[model_size], target_folder)
76
+ download_file(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP[model_size], target_folder)
77
+ if dataset is not None:
78
+ assert dataset in ['multiref', 'dstc'], \
79
+ 'dataset has to be \'multiref\' or \'dstc\''
80
+ assert from_scratch in [True, False], 'from scratch has to be True or False'
81
+
82
+ if from_scratch:
83
+ model_train_type = model_size + '_fs'
84
+ else:
85
+ model_train_type = model_size + '_ft'
86
+ if model_train_type not in LSP_MODEL_URL[dataset]:
87
+ k = ','.join(list(LSP_MODEL_URL[dataset].keys()))
88
+ raise ValueError(f'\'{model_train_type}\' not exist for dataset \'{dataset}\', please choose from [{k}]')
89
+ download_file(LSP_MODEL_URL[dataset][model_train_type], target_folder)
90
+ return target_folder
91
+
env.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT license.
3
+ import os
4
+
5
+
6
+ END_OF_TURN_TOKEN = '<|endofturn|>'
7
+ END_OF_TEXT_TOKEN = '<|endoftext|>'
8
+ PROJECT_FOLDER = os.path.dirname(__file__)
gradiodemo.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # A 3rd party demo contributed by Github user AK391 (https://github.com/AK391). This is not implemented by Microsoft and Microsoft does not own any IP with this implementation and associated demo.
2
+ # Microsoft has not tested the generation of this demo and is not responsible for any offensive or biased generation from this demo.
3
+ # Please contact the creator AK391 (https://github.com/AK391) for any potential issue.
4
+
5
+
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
+ import torch
8
+ import gradio as gr
9
+
10
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
11
+ model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
12
+
13
+ def dialogpt(text):
14
+ # encode the new user input, add the eos_token and return a tensor in Pytorch
15
+ for step in range(50000):
16
+
17
+ new_user_input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors='pt')
18
+
19
+ # append the new user input tokens to the chat history
20
+ bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids
21
+
22
+ # generated a response while limiting the total chat history to 1000 tokens,
23
+ chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
24
+
25
+ # pretty print last ouput tokens from bot
26
+ return tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
27
+
28
+ inputs = gr.inputs.Textbox(lines=1, label="Input Text")
29
+ outputs = gr.outputs.Textbox(label="DialoGPT")
30
+
31
+ title = "DialoGPT"
32
+ description = "demo for Microsoft DialoGPT with Hugging Face transformers. To use it, simply input text or click one of the examples text to load them. Read more at the links below. *This is not a Microsoft product and is developed for Gradio*"
33
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1911.00536'>DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation</a> | <a href='https://github.com/microsoft/DialoGPT'>Github Repo</a> | <a href='https://huggingface.co/microsoft/DialoGPT-large'>Hugging Face DialoGPT-large</a></p>"
34
+ examples = [
35
+ ["Hi, how are you?"],
36
+ ["How far away is the moon?"],
37
+ ]
38
+
39
+ gr.Interface(dialogpt, inputs, outputs, title=title, description=description, article=article, examples=examples).launch(debug=True)
prepro.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT license.
3
+ """
4
+ preprocess input data into feature and stores binary as python shelve DB
5
+ each chunk is gzipped JSON string
6
+ """
7
+ import argparse
8
+ import gzip
9
+ import json
10
+ import subprocess as sp
11
+ import shelve
12
+ import os
13
+ from os.path import dirname, exists, join
14
+
15
+ import torch
16
+ from lsp_model import GPT2Tokenizer
17
+ from tqdm import tqdm
18
+
19
+ from env import END_OF_TEXT_TOKEN
20
+ from gpt2_training.train_utils import InputFeatures_train as InputFeatures
21
+
22
+
23
+ def _get_file_len(corpus):
24
+ n_line = int(sp.check_output(f"wc -l {corpus}".split(),
25
+ universal_newlines=True).split()[0])
26
+ return n_line
27
+
28
+
29
+ def _norm_text(text):
30
+ w, *toks = text.strip().split()
31
+ try:
32
+ w = float(w)
33
+ except Exception:
34
+ toks = [w] + toks
35
+ w = 1.0
36
+ return w, ' '.join(toks)
37
+
38
+
39
+ def _get_inputs_from_text(text, tokenizer):
40
+ srcs, tgt = text.strip().split('\t')
41
+ weights = []
42
+ inputs = []
43
+ for src in srcs.split(' EOS '):
44
+ src_weight, src = _norm_text(src)
45
+ context_id = tokenizer.encode(src)
46
+ weights.append(src_weight)
47
+ inputs.append(context_id)
48
+ tgt_weight, tgt = _norm_text(tgt)
49
+ if tgt_weight != 0:
50
+ response_id = tokenizer.encode(tgt)
51
+ weights.append(tgt_weight)
52
+ inputs.append(response_id)
53
+ return weights, inputs
54
+
55
+
56
+ def _make_features(id_, weights, inputs, tokenizer, max_len):
57
+ end_of_text_id = tokenizer.encoder[END_OF_TEXT_TOKEN]
58
+ features = []
59
+ sents = []
60
+ ws = []
61
+ len_ = 0
62
+ i = 0
63
+ for ids, w in zip(inputs, weights):
64
+ if len(ids) > max_len:
65
+ if len(sents) >= 2:
66
+ feat = _make_feature(id_ + i, sents, ws, end_of_text_id)
67
+ if feat is not None:
68
+ features.append(feat)
69
+ i += 1
70
+ len_ = 0
71
+ sents = []
72
+ ws = []
73
+ continue
74
+ elif len_ > max_len:
75
+ feat = _make_feature(id_ + i, sents, ws, end_of_text_id)
76
+ if feat is not None:
77
+ features.append(feat)
78
+ i += 1
79
+ len_ = len(sents[-1]) + 1
80
+ sents = sents[-1:]
81
+ ws = ws[-1:]
82
+ len_ += (len(ids) + 1)
83
+ sents.append(ids)
84
+ ws.append(w)
85
+ if len(sents) >= 2:
86
+ feat = _make_feature(id_ + i, sents, ws, end_of_text_id)
87
+ if feat is not None:
88
+ features.append(feat)
89
+
90
+ return features
91
+
92
+
93
+ def _make_feature(id_, sents, ws, eos):
94
+ if all(w == 0 for w in ws[1:]):
95
+ return None
96
+ input_ids = [i for s in sents for i in s+[eos]][:-1]
97
+ lm_labels = []
98
+ weights = []
99
+ token_type_ids = [] # this becomes round ids
100
+ for i, (s, w) in enumerate(zip(sents, ws)):
101
+ if i == 0:
102
+ lm_labels += [-1] * len(s)
103
+ weights += [0.0] * len(s)
104
+ token_type_ids += [0] * len(s)
105
+ continue
106
+
107
+ token_type_ids += [i] * (len(s) + 1)
108
+ if w == 0.0:
109
+ lm_labels += [-1] * (len(s) + 1)
110
+ weights += [0.0] * (len(s) + 1)
111
+ else:
112
+ lm_labels += (s + [eos])
113
+ weights += [w] * (len(s) + 1)
114
+
115
+ # handle trailing -1's
116
+ i = len(lm_labels) - 1
117
+ while i >= 0:
118
+ if lm_labels[i] != -1:
119
+ break
120
+ i -= 1
121
+ input_ids = input_ids[:i+1]
122
+ lm_labels = lm_labels[:i+1]
123
+ weights = weights[:i+1]
124
+ token_type_ids = token_type_ids[:i+1]
125
+
126
+ # pad to multiples of 8
127
+ while len(input_ids) % 8 != 0:
128
+ input_ids.append(0)
129
+ token_type_ids.append(0)
130
+ lm_labels.append(-1)
131
+ weights.append(0.0)
132
+
133
+ position_ids = list(range(len(input_ids)))
134
+ assert (len(input_ids) == len(position_ids) == len(token_type_ids)
135
+ == len(lm_labels) == len(weights))
136
+ assert len(input_ids) % 8 == 0
137
+ if len(input_ids) == 0:
138
+ import pdb
139
+ pdb.set_trace()
140
+ feature = InputFeatures(id_, input_ids, position_ids, token_type_ids,
141
+ lm_labels, weights)
142
+ return feature
143
+
144
+
145
+ def main(args):
146
+ toker = GPT2Tokenizer.from_pretrained('gpt2')
147
+ attrs = []
148
+ if args.reverse:
149
+ attrs.append('reverse')
150
+ if args.two_turn:
151
+ attrs.append('2turn')
152
+ if attrs:
153
+ db_path = (f'{args.corpus[:-4]}.{args.max_seq_len}len.'
154
+ f'{".".join(attrs)}.db/db')
155
+ else:
156
+ db_path = f'{args.corpus[:-4]}.{args.max_seq_len}len.db/db'
157
+ if exists(dirname(db_path)):
158
+ raise ValueError('Found existing DB, please backup')
159
+ else:
160
+ os.makedirs(dirname(db_path))
161
+ with open(args.corpus, "r", encoding="utf-8") as reader, \
162
+ shelve.open(db_path, 'n') as db:
163
+ chunk = []
164
+ n_chunk = 0
165
+ n_example = 0
166
+ for line in tqdm(reader, total=_get_file_len(args.corpus)):
167
+ try:
168
+ if len(chunk) >= args.chunk_size:
169
+ # save and renew chunk
170
+ db[f'chunk_{n_chunk}'] = gzip.compress(
171
+ json.dumps(chunk[:args.chunk_size]).encode('utf-8'))
172
+ chunk = chunk[args.chunk_size:]
173
+ n_chunk += 1
174
+
175
+ weights, inputs = _get_inputs_from_text(line, toker)
176
+ if args.reverse:
177
+ weights = list(reversed(weights))
178
+ inputs = list(reversed(inputs))
179
+ if args.two_turn:
180
+ weights = weights[:2]
181
+ inputs = inputs[:2]
182
+ if len(weights) < 2:
183
+ continue
184
+ features = _make_features(n_example, weights, inputs,
185
+ toker, args.max_seq_len)
186
+ for feature in features:
187
+ chunk.append(vars(feature))
188
+ n_example += 1
189
+ except Exception as e:
190
+ print('!!! prepro exception !!!', e)
191
+ continue
192
+ # save last chunk
193
+ db[f'chunk_{n_chunk}'] = gzip.compress(
194
+ json.dumps(chunk).encode('utf-8'))
195
+ # save relevant information to reproduce
196
+ meta = {'n_example': n_example,
197
+ 'chunk_size': args.chunk_size,
198
+ 'max_seq_len': args.max_seq_len,
199
+ 'reverse': args.reverse,
200
+ 'two_turn': args.two_turn}
201
+ with open(join(dirname(db_path), 'meta.json'), 'w') as writer:
202
+ json.dump(meta, writer, indent=4)
203
+ torch.save(toker, join(dirname(db_path), 'tokenizer.pt'))
204
+
205
+
206
+ if __name__ == '__main__':
207
+ parser = argparse.ArgumentParser()
208
+ parser.add_argument('--corpus', required=True,
209
+ help='file name of training corpus (should be .tsv)')
210
+ parser.add_argument('--chunk_size', type=int, default=65536,
211
+ help='num of data examples in a storing chunk')
212
+ parser.add_argument('--max_seq_len', type=int, default=128,
213
+ help='discard data longer than this')
214
+ parser.add_argument('--reverse', action='store_true',
215
+ help='reverse the src tgt')
216
+ parser.add_argument('--two_turn', action='store_true',
217
+ help='take only the first 2 turns')
218
+
219
+ args = parser.parse_args()
220
+
221
+ main(args)
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers
2
+ ftfy
3
+ regex
4
+ tqdm
5
+ torch
6
+ torchvision
7
+ pycrypto
8
+ flashtext
9
+ zstandard
10
+ nltk