Upload 17 files
Browse files- .gitignore +475 -0
- LICENSE +21 -0
- LSP-generic.yml +70 -0
- LSP-linux.yml +77 -0
- LSP_train.py +386 -0
- MANIFEST.in +1 -0
- README.md +591 -29
- SECURITY.md +35 -0
- config.json +30 -0
- data_config.py +17 -0
- data_loader.py +291 -0
- demo.py +128 -0
- demo_utils.py +91 -0
- env.py +8 -0
- gradiodemo.py +39 -0
- prepro.py +221 -0
- requirements.txt +10 -0
.gitignore
ADDED
|
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Ignore Visual Studio temporary files, build results, and
|
| 2 |
+
## files generated by popular Visual Studio add-ons.
|
| 3 |
+
##
|
| 4 |
+
## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
|
| 5 |
+
|
| 6 |
+
# User-specific files
|
| 7 |
+
*.suo
|
| 8 |
+
*.user
|
| 9 |
+
*.userosscache
|
| 10 |
+
*.sln.docstates
|
| 11 |
+
|
| 12 |
+
# User-specific files (MonoDevelop/Xamarin Studio)
|
| 13 |
+
*.userprefs
|
| 14 |
+
|
| 15 |
+
# Build results
|
| 16 |
+
[Dd]ebug/
|
| 17 |
+
[Dd]ebugPublic/
|
| 18 |
+
[Rr]elease/
|
| 19 |
+
[Rr]eleases/
|
| 20 |
+
x64/
|
| 21 |
+
x86/
|
| 22 |
+
bld/
|
| 23 |
+
[Bb]in/
|
| 24 |
+
[Oo]bj/
|
| 25 |
+
[Ll]og/
|
| 26 |
+
|
| 27 |
+
# Visual Studio 2015/2017 cache/options directory
|
| 28 |
+
.vs/
|
| 29 |
+
# Uncomment if you have tasks that create the project's static files in wwwroot
|
| 30 |
+
#wwwroot/
|
| 31 |
+
|
| 32 |
+
# Visual Studio 2017 auto generated files
|
| 33 |
+
Generated\ Files/
|
| 34 |
+
|
| 35 |
+
# MSTest test Results
|
| 36 |
+
[Tt]est[Rr]esult*/
|
| 37 |
+
[Bb]uild[Ll]og.*
|
| 38 |
+
|
| 39 |
+
# NUNIT
|
| 40 |
+
*.VisualState.xml
|
| 41 |
+
TestResult.xml
|
| 42 |
+
|
| 43 |
+
# Build Results of an ATL Project
|
| 44 |
+
[Dd]ebugPS/
|
| 45 |
+
[Rr]eleasePS/
|
| 46 |
+
dlldata.c
|
| 47 |
+
|
| 48 |
+
# Benchmark Results
|
| 49 |
+
BenchmarkDotNet.Artifacts/
|
| 50 |
+
|
| 51 |
+
# .NET Core
|
| 52 |
+
project.lock.json
|
| 53 |
+
project.fragment.lock.json
|
| 54 |
+
artifacts/
|
| 55 |
+
**/Properties/launchSettings.json
|
| 56 |
+
|
| 57 |
+
# StyleCop
|
| 58 |
+
StyleCopReport.xml
|
| 59 |
+
|
| 60 |
+
# Files built by Visual Studio
|
| 61 |
+
*_i.c
|
| 62 |
+
*_p.c
|
| 63 |
+
*_i.h
|
| 64 |
+
*.ilk
|
| 65 |
+
*.meta
|
| 66 |
+
*.obj
|
| 67 |
+
*.iobj
|
| 68 |
+
*.pch
|
| 69 |
+
*.pdb
|
| 70 |
+
*.ipdb
|
| 71 |
+
*.pgc
|
| 72 |
+
*.pgd
|
| 73 |
+
*.rsp
|
| 74 |
+
*.sbr
|
| 75 |
+
*.tlb
|
| 76 |
+
*.tli
|
| 77 |
+
*.tlh
|
| 78 |
+
*.tmp
|
| 79 |
+
*.tmp_proj
|
| 80 |
+
*.log
|
| 81 |
+
*.vspscc
|
| 82 |
+
*.vssscc
|
| 83 |
+
.builds
|
| 84 |
+
*.pidb
|
| 85 |
+
*.svclog
|
| 86 |
+
*.scc
|
| 87 |
+
|
| 88 |
+
# Chutzpah Test files
|
| 89 |
+
_Chutzpah*
|
| 90 |
+
|
| 91 |
+
# Visual C++ cache files
|
| 92 |
+
ipch/
|
| 93 |
+
*.aps
|
| 94 |
+
*.ncb
|
| 95 |
+
*.opendb
|
| 96 |
+
*.opensdf
|
| 97 |
+
*.sdf
|
| 98 |
+
*.cachefile
|
| 99 |
+
*.VC.db
|
| 100 |
+
*.VC.VC.opendb
|
| 101 |
+
|
| 102 |
+
# Visual Studio profiler
|
| 103 |
+
*.psess
|
| 104 |
+
*.vsp
|
| 105 |
+
*.vspx
|
| 106 |
+
*.sap
|
| 107 |
+
|
| 108 |
+
# Visual Studio Trace Files
|
| 109 |
+
*.e2e
|
| 110 |
+
|
| 111 |
+
# TFS 2012 Local Workspace
|
| 112 |
+
$tf/
|
| 113 |
+
|
| 114 |
+
# Guidance Automation Toolkit
|
| 115 |
+
*.gpState
|
| 116 |
+
|
| 117 |
+
# ReSharper is a .NET coding add-in
|
| 118 |
+
_ReSharper*/
|
| 119 |
+
*.[Rr]e[Ss]harper
|
| 120 |
+
*.DotSettings.user
|
| 121 |
+
|
| 122 |
+
# JustCode is a .NET coding add-in
|
| 123 |
+
.JustCode
|
| 124 |
+
|
| 125 |
+
# TeamCity is a build add-in
|
| 126 |
+
_TeamCity*
|
| 127 |
+
|
| 128 |
+
# DotCover is a Code Coverage Tool
|
| 129 |
+
*.dotCover
|
| 130 |
+
|
| 131 |
+
# AxoCover is a Code Coverage Tool
|
| 132 |
+
.axoCover/*
|
| 133 |
+
!.axoCover/settings.json
|
| 134 |
+
|
| 135 |
+
# Visual Studio code coverage results
|
| 136 |
+
*.coverage
|
| 137 |
+
*.coveragexml
|
| 138 |
+
|
| 139 |
+
# NCrunch
|
| 140 |
+
_NCrunch_*
|
| 141 |
+
.*crunch*.local.xml
|
| 142 |
+
nCrunchTemp_*
|
| 143 |
+
|
| 144 |
+
# MightyMoose
|
| 145 |
+
*.mm.*
|
| 146 |
+
AutoTest.Net/
|
| 147 |
+
|
| 148 |
+
# Web workbench (sass)
|
| 149 |
+
.sass-cache/
|
| 150 |
+
|
| 151 |
+
# Installshield output folder
|
| 152 |
+
[Ee]xpress/
|
| 153 |
+
|
| 154 |
+
# DocProject is a documentation generator add-in
|
| 155 |
+
DocProject/buildhelp/
|
| 156 |
+
DocProject/Help/*.HxT
|
| 157 |
+
DocProject/Help/*.HxC
|
| 158 |
+
DocProject/Help/*.hhc
|
| 159 |
+
DocProject/Help/*.hhk
|
| 160 |
+
DocProject/Help/*.hhp
|
| 161 |
+
DocProject/Help/Html2
|
| 162 |
+
DocProject/Help/html
|
| 163 |
+
|
| 164 |
+
# Click-Once directory
|
| 165 |
+
publish/
|
| 166 |
+
|
| 167 |
+
# Publish Web Output
|
| 168 |
+
*.[Pp]ublish.xml
|
| 169 |
+
*.azurePubxml
|
| 170 |
+
# Note: Comment the next line if you want to checkin your web deploy settings,
|
| 171 |
+
# but database connection strings (with potential passwords) will be unencrypted
|
| 172 |
+
*.pubxml
|
| 173 |
+
*.publishproj
|
| 174 |
+
|
| 175 |
+
# Microsoft Azure Web App publish settings. Comment the next line if you want to
|
| 176 |
+
# checkin your Azure Web App publish settings, but sensitive information contained
|
| 177 |
+
# in these scripts will be unencrypted
|
| 178 |
+
PublishScripts/
|
| 179 |
+
|
| 180 |
+
# NuGet Packages
|
| 181 |
+
*.nupkg
|
| 182 |
+
# The packages folder can be ignored because of Package Restore
|
| 183 |
+
**/[Pp]ackages/*
|
| 184 |
+
# except build/, which is used as an MSBuild target.
|
| 185 |
+
!**/[Pp]ackages/build/
|
| 186 |
+
# Uncomment if necessary however generally it will be regenerated when needed
|
| 187 |
+
#!**/[Pp]ackages/repositories.config
|
| 188 |
+
# NuGet v3's project.json files produces more ignorable files
|
| 189 |
+
*.nuget.props
|
| 190 |
+
*.nuget.targets
|
| 191 |
+
|
| 192 |
+
# Microsoft Azure Build Output
|
| 193 |
+
csx/
|
| 194 |
+
*.build.csdef
|
| 195 |
+
|
| 196 |
+
# Microsoft Azure Emulator
|
| 197 |
+
ecf/
|
| 198 |
+
rcf/
|
| 199 |
+
|
| 200 |
+
# Windows Store app package directories and files
|
| 201 |
+
AppPackages/
|
| 202 |
+
BundleArtifacts/
|
| 203 |
+
Package.StoreAssociation.xml
|
| 204 |
+
_pkginfo.txt
|
| 205 |
+
*.appx
|
| 206 |
+
|
| 207 |
+
# Visual Studio cache files
|
| 208 |
+
# files ending in .cache can be ignored
|
| 209 |
+
*.[Cc]ache
|
| 210 |
+
# but keep track of directories ending in .cache
|
| 211 |
+
!*.[Cc]ache/
|
| 212 |
+
|
| 213 |
+
# Others
|
| 214 |
+
ClientBin/
|
| 215 |
+
~$*
|
| 216 |
+
*~
|
| 217 |
+
*.dbmdl
|
| 218 |
+
*.dbproj.schemaview
|
| 219 |
+
*.jfm
|
| 220 |
+
*.pfx
|
| 221 |
+
*.publishsettings
|
| 222 |
+
orleans.codegen.cs
|
| 223 |
+
|
| 224 |
+
# Including strong name files can present a security risk
|
| 225 |
+
# (https://github.com/github/gitignore/pull/2483#issue-259490424)
|
| 226 |
+
#*.snk
|
| 227 |
+
|
| 228 |
+
# Since there are multiple workflows, uncomment next line to ignore bower_components
|
| 229 |
+
# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
|
| 230 |
+
#bower_components/
|
| 231 |
+
|
| 232 |
+
# RIA/Silverlight projects
|
| 233 |
+
Generated_Code/
|
| 234 |
+
|
| 235 |
+
# Backup & report files from converting an old project file
|
| 236 |
+
# to a newer Visual Studio version. Backup files are not needed,
|
| 237 |
+
# because we have git ;-)
|
| 238 |
+
_UpgradeReport_Files/
|
| 239 |
+
Backup*/
|
| 240 |
+
UpgradeLog*.XML
|
| 241 |
+
UpgradeLog*.htm
|
| 242 |
+
ServiceFabricBackup/
|
| 243 |
+
*.rptproj.bak
|
| 244 |
+
|
| 245 |
+
# SQL Server files
|
| 246 |
+
*.mdf
|
| 247 |
+
*.ldf
|
| 248 |
+
*.ndf
|
| 249 |
+
|
| 250 |
+
# Business Intelligence projects
|
| 251 |
+
*.rdl.data
|
| 252 |
+
*.bim.layout
|
| 253 |
+
*.bim_*.settings
|
| 254 |
+
*.rptproj.rsuser
|
| 255 |
+
|
| 256 |
+
# Microsoft Fakes
|
| 257 |
+
FakesAssemblies/
|
| 258 |
+
|
| 259 |
+
# GhostDoc plugin setting file
|
| 260 |
+
*.GhostDoc.xml
|
| 261 |
+
|
| 262 |
+
# Node.js Tools for Visual Studio
|
| 263 |
+
.ntvs_analysis.dat
|
| 264 |
+
node_modules/
|
| 265 |
+
|
| 266 |
+
# Visual Studio 6 build log
|
| 267 |
+
*.plg
|
| 268 |
+
|
| 269 |
+
# Visual Studio 6 workspace options file
|
| 270 |
+
*.opt
|
| 271 |
+
|
| 272 |
+
# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
|
| 273 |
+
*.vbw
|
| 274 |
+
|
| 275 |
+
# Visual Studio LightSwitch build output
|
| 276 |
+
**/*.HTMLClient/GeneratedArtifacts
|
| 277 |
+
**/*.DesktopClient/GeneratedArtifacts
|
| 278 |
+
**/*.DesktopClient/ModelManifest.xml
|
| 279 |
+
**/*.Server/GeneratedArtifacts
|
| 280 |
+
**/*.Server/ModelManifest.xml
|
| 281 |
+
_Pvt_Extensions
|
| 282 |
+
|
| 283 |
+
# Paket dependency manager
|
| 284 |
+
.paket/paket.exe
|
| 285 |
+
paket-files/
|
| 286 |
+
|
| 287 |
+
# FAKE - F# Make
|
| 288 |
+
.fake/
|
| 289 |
+
|
| 290 |
+
# JetBrains Rider
|
| 291 |
+
.idea/
|
| 292 |
+
*.sln.iml
|
| 293 |
+
|
| 294 |
+
# CodeRush
|
| 295 |
+
.cr/
|
| 296 |
+
|
| 297 |
+
# Python Tools for Visual Studio (PTVS)
|
| 298 |
+
__pycache__/
|
| 299 |
+
*.pyc
|
| 300 |
+
|
| 301 |
+
# Cake - Uncomment if you are using it
|
| 302 |
+
# tools/**
|
| 303 |
+
# !tools/packages.config
|
| 304 |
+
|
| 305 |
+
# Tabs Studio
|
| 306 |
+
*.tss
|
| 307 |
+
|
| 308 |
+
# Telerik's JustMock configuration file
|
| 309 |
+
*.jmconfig
|
| 310 |
+
|
| 311 |
+
# BizTalk build output
|
| 312 |
+
*.btp.cs
|
| 313 |
+
*.btm.cs
|
| 314 |
+
*.odx.cs
|
| 315 |
+
*.xsd.cs
|
| 316 |
+
|
| 317 |
+
# OpenCover UI analysis results
|
| 318 |
+
OpenCover/
|
| 319 |
+
|
| 320 |
+
# Azure Stream Analytics local run output
|
| 321 |
+
ASALocalRun/
|
| 322 |
+
|
| 323 |
+
# MSBuild Binary and Structured Log
|
| 324 |
+
*.binlog
|
| 325 |
+
|
| 326 |
+
# NVidia Nsight GPU debugger configuration file
|
| 327 |
+
*.nvuser
|
| 328 |
+
|
| 329 |
+
# MFractors (Xamarin productivity tool) working folder
|
| 330 |
+
.mfractor/
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
# preprocessed data caused by testing
|
| 334 |
+
tests/corpus/*.db
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
# Initially taken from Github's Python gitignore file
|
| 338 |
+
|
| 339 |
+
# Byte-compiled / optimized / DLL files
|
| 340 |
+
__pycache__/
|
| 341 |
+
.idea/
|
| 342 |
+
*.py[cod]
|
| 343 |
+
*$py.class
|
| 344 |
+
|
| 345 |
+
# C extensions
|
| 346 |
+
*.so
|
| 347 |
+
|
| 348 |
+
# Distribution / packaging
|
| 349 |
+
.DS_Store
|
| 350 |
+
.Python
|
| 351 |
+
build/
|
| 352 |
+
develop-eggs/
|
| 353 |
+
dist/
|
| 354 |
+
downloads/
|
| 355 |
+
eggs/
|
| 356 |
+
.eggs/
|
| 357 |
+
lib/
|
| 358 |
+
lib64/
|
| 359 |
+
parts/
|
| 360 |
+
sdist/
|
| 361 |
+
var/
|
| 362 |
+
wheels/
|
| 363 |
+
*.egg-info/
|
| 364 |
+
.installed.cfg
|
| 365 |
+
*.egg
|
| 366 |
+
MANIFEST
|
| 367 |
+
|
| 368 |
+
# PyInstaller
|
| 369 |
+
# Usually these files are written by a python script from a template
|
| 370 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 371 |
+
*.manifest
|
| 372 |
+
*.spec
|
| 373 |
+
|
| 374 |
+
# Installer logs
|
| 375 |
+
pip-log.txt
|
| 376 |
+
pip-delete-this-directory.txt
|
| 377 |
+
|
| 378 |
+
# Unit test / coverage reports
|
| 379 |
+
htmlcov/
|
| 380 |
+
.tox/
|
| 381 |
+
.nox/
|
| 382 |
+
.coverage
|
| 383 |
+
.coverage.*
|
| 384 |
+
.cache
|
| 385 |
+
nosetests.xml
|
| 386 |
+
coverage.xml
|
| 387 |
+
*.cover
|
| 388 |
+
.hypothesis/
|
| 389 |
+
.pytest_cache/
|
| 390 |
+
|
| 391 |
+
# Translations
|
| 392 |
+
*.mo
|
| 393 |
+
*.pot
|
| 394 |
+
|
| 395 |
+
# Django stuff:
|
| 396 |
+
*.log
|
| 397 |
+
local_settings.py
|
| 398 |
+
db.sqlite3
|
| 399 |
+
|
| 400 |
+
# Flask stuff:
|
| 401 |
+
instance/
|
| 402 |
+
.webassets-cache
|
| 403 |
+
|
| 404 |
+
# Scrapy stuff:
|
| 405 |
+
.scrapy
|
| 406 |
+
|
| 407 |
+
# Sphinx documentation
|
| 408 |
+
docs/_build/
|
| 409 |
+
|
| 410 |
+
# PyBuilder
|
| 411 |
+
target/
|
| 412 |
+
|
| 413 |
+
# Jupyter Notebook
|
| 414 |
+
.ipynb_checkpoints
|
| 415 |
+
|
| 416 |
+
# IPython
|
| 417 |
+
profile_default/
|
| 418 |
+
ipython_config.py
|
| 419 |
+
|
| 420 |
+
# pyenv
|
| 421 |
+
.python-version
|
| 422 |
+
|
| 423 |
+
# celery beat schedule file
|
| 424 |
+
celerybeat-schedule
|
| 425 |
+
|
| 426 |
+
# SageMath parsed files
|
| 427 |
+
*.sage.py
|
| 428 |
+
|
| 429 |
+
# unstaged code
|
| 430 |
+
# run_gpt2.py
|
| 431 |
+
|
| 432 |
+
# Environments
|
| 433 |
+
.env
|
| 434 |
+
.venv
|
| 435 |
+
env/
|
| 436 |
+
venv/
|
| 437 |
+
ENV/
|
| 438 |
+
env.bak/
|
| 439 |
+
venv.bak/
|
| 440 |
+
|
| 441 |
+
# Spyder project settings
|
| 442 |
+
.spyderproject
|
| 443 |
+
.spyproject
|
| 444 |
+
|
| 445 |
+
# Rope project settings
|
| 446 |
+
.ropeproject
|
| 447 |
+
|
| 448 |
+
# mkdocs documentation
|
| 449 |
+
/site
|
| 450 |
+
|
| 451 |
+
# mypy
|
| 452 |
+
.mypy_cache/
|
| 453 |
+
.dmypy.json
|
| 454 |
+
dmypy.json
|
| 455 |
+
|
| 456 |
+
# Pyre type checker
|
| 457 |
+
.pyre/
|
| 458 |
+
|
| 459 |
+
# vscode
|
| 460 |
+
.vscode
|
| 461 |
+
|
| 462 |
+
# TF code
|
| 463 |
+
tensorflow_code
|
| 464 |
+
|
| 465 |
+
# Models
|
| 466 |
+
# models
|
| 467 |
+
models/
|
| 468 |
+
# test.txt
|
| 469 |
+
# dstc/
|
| 470 |
+
# scripts/
|
| 471 |
+
# philly/
|
| 472 |
+
# demo/
|
| 473 |
+
# docker/
|
| 474 |
+
# prepro_v4.py
|
| 475 |
+
# prepro.py
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) Microsoft Corporation.
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE
|
LSP-generic.yml
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: LSP
|
| 2 |
+
channels:
|
| 3 |
+
- pytorch
|
| 4 |
+
- defaults
|
| 5 |
+
dependencies:
|
| 6 |
+
- _libgcc_mutex=0.1=main
|
| 7 |
+
- blas=1.0=mkl
|
| 8 |
+
- ca-certificates=2019.5.15=1
|
| 9 |
+
- certifi=2019.6.16=py36_1
|
| 10 |
+
- cffi=1.12.3
|
| 11 |
+
- cudatoolkit=10.0.130=0
|
| 12 |
+
- freetype=2.9.1
|
| 13 |
+
- intel-openmp=2019.4
|
| 14 |
+
- jpeg=9b
|
| 15 |
+
- libpng=1.6.37
|
| 16 |
+
- libtiff=4.0.10
|
| 17 |
+
- mkl=2019.4
|
| 18 |
+
- mkl-service=2.3.0
|
| 19 |
+
- mkl_fft=1.0.14
|
| 20 |
+
- mkl_random=1.0.2
|
| 21 |
+
- ninja=1.9.0
|
| 22 |
+
- numpy=1.16.5
|
| 23 |
+
- numpy-base=1.16.5
|
| 24 |
+
- olefile=0.46=py36_0
|
| 25 |
+
- openssl=1.1.1d
|
| 26 |
+
- pillow=6.1.0
|
| 27 |
+
- pip=19.2.2=py36_0
|
| 28 |
+
- pycparser=2.19=py36_0
|
| 29 |
+
- python=3.6.9
|
| 30 |
+
- pytorch=1.2.0
|
| 31 |
+
- setuptools=41.0.1=py36_0
|
| 32 |
+
- six=1.12.0=py36_0
|
| 33 |
+
- sqlite=3.29.0
|
| 34 |
+
- tk=8.6.8
|
| 35 |
+
- torchvision=0.4.0=py36_cu100
|
| 36 |
+
- wheel=0.33.4=py36_0
|
| 37 |
+
- xz=5.2.4
|
| 38 |
+
- zlib=1.2.11
|
| 39 |
+
- zstd=1.3.7
|
| 40 |
+
- nltk=3.4.1
|
| 41 |
+
- pip:
|
| 42 |
+
- backcall==0.1.0
|
| 43 |
+
- boto3==1.9.228
|
| 44 |
+
- botocore==1.12.228
|
| 45 |
+
- chardet==3.0.4
|
| 46 |
+
- decorator==4.4.0
|
| 47 |
+
- docutils==0.15.2
|
| 48 |
+
- idna==2.8
|
| 49 |
+
- ipython==7.8.0
|
| 50 |
+
- ipython-genutils==0.2.0
|
| 51 |
+
- jedi==0.15.1
|
| 52 |
+
- jmespath==0.9.4
|
| 53 |
+
- parso==0.5.1
|
| 54 |
+
- pexpect==4.7.0
|
| 55 |
+
- pickleshare==0.7.5
|
| 56 |
+
- prompt-toolkit==2.0.9
|
| 57 |
+
- ptyprocess==0.6.0
|
| 58 |
+
- pygments==2.4.2
|
| 59 |
+
- python-dateutil==2.8.0
|
| 60 |
+
- pytorch-pretrained-bert==0.6.1
|
| 61 |
+
- regex==2019.8.19
|
| 62 |
+
- requests==2.22.0
|
| 63 |
+
- s3transfer==0.2.1
|
| 64 |
+
- tqdm==4.35.0
|
| 65 |
+
- traitlets==4.3.2
|
| 66 |
+
- urllib3==1.25.3
|
| 67 |
+
- wcwidth==0.1.7
|
| 68 |
+
- flashtext==2.7
|
| 69 |
+
prefix: /home/JJteam/anaconda3/envs/LSP
|
| 70 |
+
|
LSP-linux.yml
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: LSP
|
| 2 |
+
channels:
|
| 3 |
+
- pytorch
|
| 4 |
+
- defaults
|
| 5 |
+
dependencies:
|
| 6 |
+
- _libgcc_mutex=0.1=main
|
| 7 |
+
- blas=1.0=mkl
|
| 8 |
+
- ca-certificates=2019.5.15=1
|
| 9 |
+
- certifi=2019.6.16=py36_1
|
| 10 |
+
- cffi=1.12.3=py36h2e261b9_0
|
| 11 |
+
- cudatoolkit=10.0.130=0
|
| 12 |
+
- freetype=2.9.1=h8a8886c_1
|
| 13 |
+
- intel-openmp=2019.4=243
|
| 14 |
+
- jpeg=9b=h024ee3a_2
|
| 15 |
+
- libedit=3.1.20181209=hc058e9b_0
|
| 16 |
+
- libffi=3.2.1=hd88cf55_4
|
| 17 |
+
- libgcc-ng=9.1.0=hdf63c60_0
|
| 18 |
+
- libgfortran-ng=7.3.0=hdf63c60_0
|
| 19 |
+
- libpng=1.6.37=hbc83047_0
|
| 20 |
+
- libstdcxx-ng=9.1.0=hdf63c60_0
|
| 21 |
+
- libtiff=4.0.10=h2733197_2
|
| 22 |
+
- mkl=2019.4=243
|
| 23 |
+
- mkl-service=2.3.0=py36he904b0f_0
|
| 24 |
+
- mkl_fft=1.0.14=py36ha843d7b_0
|
| 25 |
+
- mkl_random=1.0.2=py36hd81dba3_0
|
| 26 |
+
- ncurses=6.1=he6710b0_1
|
| 27 |
+
- ninja=1.9.0=py36hfd86e86_0
|
| 28 |
+
- numpy=1.16.5=py36h7e9f1db_0
|
| 29 |
+
- numpy-base=1.16.5=py36hde5b4d6_0
|
| 30 |
+
- olefile=0.46=py36_0
|
| 31 |
+
- openssl=1.1.1d=h7b6447c_1
|
| 32 |
+
- pillow=6.1.0=py36h34e0f95_0
|
| 33 |
+
- pip=19.2.2=py36_0
|
| 34 |
+
- pycparser=2.19=py36_0
|
| 35 |
+
- python=3.6.9=h265db76_0
|
| 36 |
+
- pytorch=1.2.0=py3.6_cuda10.0.130_cudnn7.6.2_0
|
| 37 |
+
- readline=7.0=h7b6447c_5
|
| 38 |
+
- setuptools=41.0.1=py36_0
|
| 39 |
+
- six=1.12.0=py36_0
|
| 40 |
+
- sqlite=3.29.0=h7b6447c_0
|
| 41 |
+
- tk=8.6.8=hbc83047_0
|
| 42 |
+
- torchvision=0.4.0=py36_cu100
|
| 43 |
+
- wheel=0.33.4=py36_0
|
| 44 |
+
- xz=5.2.4=h14c3975_4
|
| 45 |
+
- zlib=1.2.11=h7b6447c_3
|
| 46 |
+
- zstd=1.3.7=h0b5b093_0
|
| 47 |
+
- nltk=3.4.1
|
| 48 |
+
- pip:
|
| 49 |
+
- backcall==0.1.0
|
| 50 |
+
- boto3==1.9.228
|
| 51 |
+
- botocore==1.12.228
|
| 52 |
+
- chardet==3.0.4
|
| 53 |
+
- decorator==4.4.0
|
| 54 |
+
- docutils==0.15.2
|
| 55 |
+
- idna==2.8
|
| 56 |
+
- ipython==7.8.0
|
| 57 |
+
- ipython-genutils==0.2.0
|
| 58 |
+
- jedi==0.15.1
|
| 59 |
+
- jmespath==0.9.4
|
| 60 |
+
- parso==0.5.1
|
| 61 |
+
- pexpect==4.7.0
|
| 62 |
+
- pickleshare==0.7.5
|
| 63 |
+
- prompt-toolkit==2.0.9
|
| 64 |
+
- ptyprocess==0.6.0
|
| 65 |
+
- pygments==2.4.2
|
| 66 |
+
- python-dateutil==2.8.0
|
| 67 |
+
- pytorch-pretrained-bert==0.6.1
|
| 68 |
+
- regex==2019.8.19
|
| 69 |
+
- requests==2.22.0
|
| 70 |
+
- s3transfer==0.2.1
|
| 71 |
+
- tqdm==4.35.0
|
| 72 |
+
- traitlets==4.3.2
|
| 73 |
+
- urllib3==1.25.3
|
| 74 |
+
- wcwidth==0.1.7
|
| 75 |
+
- flashtext==2.7
|
| 76 |
+
prefix: /home/JJteam/anaconda3/envs/LSP
|
| 77 |
+
|
LSP_train.py
ADDED
|
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
'''
|
| 4 |
+
* @Desc: train GPT2 from scratch/ fine tuning.
|
| 5 |
+
Modified based on Huggingface GPT-2 implementation
|
| 6 |
+
'''
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
import argparse
|
| 12 |
+
import logging
|
| 13 |
+
import time
|
| 14 |
+
import tqdm
|
| 15 |
+
import datetime
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
from os.path import join
|
| 21 |
+
from torch.distributed import get_rank, get_world_size
|
| 22 |
+
|
| 23 |
+
from lsp_model import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, Adam
|
| 24 |
+
from gpt2_training.train_utils import load_model, boolean_string, set_lr, get_eval_list_same_length
|
| 25 |
+
from gpt2_training.eval_utils import eval_model_loss
|
| 26 |
+
|
| 27 |
+
from data_loader import BucketingDataLoader, DynamicBatchingLoader, DistributedBucketingDataLoader
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
from gpt2_training.distributed import all_reduce_and_rescale_tensors, all_gather_list
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
logging.basicConfig(
|
| 34 |
+
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
| 35 |
+
datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
|
| 36 |
+
logger = logging.getLogger(__name__)
|
| 37 |
+
|
| 38 |
+
INF = 100000000
|
| 39 |
+
CACHE_EMPTY_STEP = 10000
|
| 40 |
+
EVAL_STEP = 100000
|
| 41 |
+
|
| 42 |
+
#########################################################################
|
| 43 |
+
# Prepare Parser
|
| 44 |
+
##########################################################################
|
| 45 |
+
|
| 46 |
+
parser = argparse.ArgumentParser()
|
| 47 |
+
parser.add_argument('--model_name_or_path', type=str,
|
| 48 |
+
help='pretrained model name or path to local checkpoint')
|
| 49 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 50 |
+
parser.add_argument("--max_seq_length", type=int, default=128)
|
| 51 |
+
|
| 52 |
+
parser.add_argument("--skip_eval", action='store_true',
|
| 53 |
+
help='If true, skip evaluation.')
|
| 54 |
+
parser.add_argument("--init_checkpoint", type=str)
|
| 55 |
+
parser.add_argument("--train_input_file", type=str)
|
| 56 |
+
parser.add_argument("--eval_input_file", type=str)
|
| 57 |
+
parser.add_argument("--continue_from", type=int, default=0)
|
| 58 |
+
|
| 59 |
+
parser.add_argument("--train_batch_size", type=int, default=4,
|
| 60 |
+
help="batch size now means per GPU per step")
|
| 61 |
+
parser.add_argument("--gradient_accumulation_steps", type=int, default=2,
|
| 62 |
+
help="to increase effective batch size "
|
| 63 |
+
"and reduce synchronization")
|
| 64 |
+
parser.add_argument("--eval_batch_size", type=int, default=4)
|
| 65 |
+
parser.add_argument("--learning_rate", type=float, default=1e-5)
|
| 66 |
+
parser.add_argument("--num_optim_steps", type=int, default=1000000,
|
| 67 |
+
help="new API specifies num update steps")
|
| 68 |
+
parser.add_argument("--valid_step", type=int, default=10000,
|
| 69 |
+
help="how many optim steps between validations")
|
| 70 |
+
parser.add_argument("--warmup_proportion", type=float, default=0.1)
|
| 71 |
+
parser.add_argument("--warmup_steps", type=int, default=16000)
|
| 72 |
+
|
| 73 |
+
parser.add_argument("--normalize_data", type=boolean_string, default=True)
|
| 74 |
+
parser.add_argument("--fp16", type=boolean_string, default=True)
|
| 75 |
+
parser.add_argument("--lr_schedule", type=str,
|
| 76 |
+
choices=['noam', 'noamwd', 'BERT', 'None'], default='noam')
|
| 77 |
+
parser.add_argument("--loss_scale", type=float, default=0)
|
| 78 |
+
parser.add_argument("--no_token_id", type=boolean_string, default=True)
|
| 79 |
+
|
| 80 |
+
parser.add_argument("--output_dir", type=str)
|
| 81 |
+
parser.add_argument("--log_dir", type=str)
|
| 82 |
+
parser.add_argument('--pbar', type=boolean_string, default=True, help='turn on progress bar')
|
| 83 |
+
|
| 84 |
+
# distributed
|
| 85 |
+
parser.add_argument('--local_rank', type=int, default=-1,
|
| 86 |
+
help='for torch.distributed')
|
| 87 |
+
parser.add_argument('--config', help='JSON config file')
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# do normal parsing
|
| 91 |
+
args = parser.parse_args()
|
| 92 |
+
|
| 93 |
+
if args.config is not None:
|
| 94 |
+
# override argparse defaults by config JSON
|
| 95 |
+
opts = json.load(open(args.config))
|
| 96 |
+
for k, v in opts.items():
|
| 97 |
+
if isinstance(v, str):
|
| 98 |
+
# PHILLY ENV special cases
|
| 99 |
+
if 'PHILLY_JOB_DIRECTORY' in v:
|
| 100 |
+
v = v.replace('PHILLY_JOB_DIRECTORY',
|
| 101 |
+
os.environ['PHILLY_JOB_DIRECTORY'])
|
| 102 |
+
elif 'PHILLY_LOG_DIRECTORY' in v:
|
| 103 |
+
v = v.replace('PHILLY_LOG_DIRECTORY',
|
| 104 |
+
os.environ['PHILLY_LOG_DIRECTORY'])
|
| 105 |
+
setattr(args, k, v)
|
| 106 |
+
|
| 107 |
+
# command line should override config JSON
|
| 108 |
+
argv = sys.argv[1:]
|
| 109 |
+
overrides, _ = parser.parse_known_args(argv)
|
| 110 |
+
for k, v in vars(overrides).items():
|
| 111 |
+
if f'--{k}' in argv:
|
| 112 |
+
setattr(args, k, v)
|
| 113 |
+
setattr(args, 'local_rank', overrides.local_rank)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
assert args.train_batch_size % args.gradient_accumulation_steps == 0, \
|
| 117 |
+
'batch size % gradient accumulation steps != 0!'
|
| 118 |
+
args.train_batch_size = (args.train_batch_size
|
| 119 |
+
// args.gradient_accumulation_steps)
|
| 120 |
+
logger.info('train batch size = {}, '
|
| 121 |
+
'new train batch size (after gradient accumulation) = {}'.format(
|
| 122 |
+
args.train_batch_size*args.gradient_accumulation_steps,
|
| 123 |
+
args.train_batch_size))
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
if args.local_rank == -1:
|
| 127 |
+
logger.info('CUDA available? {}'.format(str(torch.cuda.is_available())))
|
| 128 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 129 |
+
n_gpu = torch.cuda.device_count()
|
| 130 |
+
args.device, args.n_gpu = device, n_gpu
|
| 131 |
+
else:
|
| 132 |
+
# distributed training
|
| 133 |
+
torch.cuda.set_device(args.local_rank)
|
| 134 |
+
device = torch.device("cuda", args.local_rank)
|
| 135 |
+
# Initializes the distributed backend which will take care of
|
| 136 |
+
# sychronizing nodes/GPUs
|
| 137 |
+
torch.distributed.init_process_group(backend='nccl')
|
| 138 |
+
n_gpu = torch.distributed.get_world_size()
|
| 139 |
+
args.device, args.n_gpu = device, 1
|
| 140 |
+
logger.info("device: {} n_gpu: {}, distributed training: {}, "
|
| 141 |
+
"16-bits training: {}".format(
|
| 142 |
+
device, n_gpu, bool(args.local_rank != -1), args.fp16))
|
| 143 |
+
|
| 144 |
+
np.random.seed(args.seed)
|
| 145 |
+
torch.random.manual_seed(args.seed)
|
| 146 |
+
torch.cuda.manual_seed(args.seed)
|
| 147 |
+
if n_gpu > 0:
|
| 148 |
+
torch.cuda.manual_seed_all(args.seed)
|
| 149 |
+
|
| 150 |
+
timestamp = datetime.datetime.now().strftime('%Y-%m-%d%H%M%S')
|
| 151 |
+
output_dir = join(args.output_dir,
|
| 152 |
+
'GPT2.{}.{}.{}gpu.{}'.format(args.learning_rate,
|
| 153 |
+
args.train_batch_size, n_gpu,
|
| 154 |
+
timestamp))
|
| 155 |
+
log_dir = args.log_dir if args.log_dir is not None and len(args.log_dir) > 0 else output_dir
|
| 156 |
+
if args.local_rank == -1 or get_rank() == 0:
|
| 157 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 158 |
+
|
| 159 |
+
logger.info('Input Argument Information')
|
| 160 |
+
args_dict = vars(args)
|
| 161 |
+
for a in args_dict:
|
| 162 |
+
logger.info('%-28s %s' % (a, args_dict[a]))
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
#########################################################################
|
| 166 |
+
# Prepare Data Set
|
| 167 |
+
##########################################################################
|
| 168 |
+
enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
|
| 169 |
+
|
| 170 |
+
config = GPT2Config.from_json_file(
|
| 171 |
+
join(args.model_name_or_path, 'config.json'))
|
| 172 |
+
|
| 173 |
+
if args.local_rank == -1:
|
| 174 |
+
train_dataloader = BucketingDataLoader(args.train_input_file,
|
| 175 |
+
args.train_batch_size,
|
| 176 |
+
args.max_seq_length)
|
| 177 |
+
else:
|
| 178 |
+
train_dataloader = DistributedBucketingDataLoader(
|
| 179 |
+
get_rank(), get_world_size(),
|
| 180 |
+
args.train_input_file, args.train_batch_size,
|
| 181 |
+
args.max_seq_length)
|
| 182 |
+
|
| 183 |
+
eval_dataloader_loss = DynamicBatchingLoader(
|
| 184 |
+
args.eval_input_file, enc, args.normalize_data,
|
| 185 |
+
args.eval_batch_size, args.max_seq_length)
|
| 186 |
+
|
| 187 |
+
eval_dataloader_gen = get_eval_list_same_length(
|
| 188 |
+
args.eval_input_file, enc, args.eval_batch_size, True)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
#########################################################################
|
| 192 |
+
# Prepare Model and Optimizer
|
| 193 |
+
##########################################################################
|
| 194 |
+
model = load_model(GPT2LMHeadModel(config), args.init_checkpoint,
|
| 195 |
+
args, verbose=True)
|
| 196 |
+
if args.local_rank != -1:
|
| 197 |
+
# when from scratch make sure initial models are the same
|
| 198 |
+
params = [p.data for p in model.parameters()]
|
| 199 |
+
all_reduce_and_rescale_tensors(
|
| 200 |
+
params, float(torch.distributed.get_world_size()))
|
| 201 |
+
|
| 202 |
+
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
| 203 |
+
total_params = sum([np.prod(p.size()) for p in model_parameters])
|
| 204 |
+
logger.info('Number of parameter = {}'.format(total_params))
|
| 205 |
+
|
| 206 |
+
param_optimizer = list(model.named_parameters())
|
| 207 |
+
no_decay = ['bias', 'ln'] # no decay for bias and LayerNorm (ln)
|
| 208 |
+
optimizer_grouped_parameters = [
|
| 209 |
+
{'params': [p for n, p in param_optimizer
|
| 210 |
+
if not any(nd in n for nd in no_decay)],
|
| 211 |
+
'weight_decay': 0.01},
|
| 212 |
+
{'params': [p for n, p in param_optimizer
|
| 213 |
+
if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
| 214 |
+
]
|
| 215 |
+
|
| 216 |
+
if args.fp16:
|
| 217 |
+
logger.info('in fp16, using FusedAdam')
|
| 218 |
+
try:
|
| 219 |
+
from apex.optimizers import FP16_Optimizer
|
| 220 |
+
from apex.optimizers import FusedAdam
|
| 221 |
+
except ImportError:
|
| 222 |
+
raise ImportError(
|
| 223 |
+
"Please install apex from https://www.github.com/nvidia/apex "
|
| 224 |
+
"to use distributed and fp16 training.")
|
| 225 |
+
|
| 226 |
+
optimizer = FusedAdam(optimizer_grouped_parameters,
|
| 227 |
+
lr=args.learning_rate,
|
| 228 |
+
bias_correction=False,
|
| 229 |
+
max_grad_norm=1.0)
|
| 230 |
+
if args.loss_scale == 0:
|
| 231 |
+
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True,
|
| 232 |
+
verbose=False)
|
| 233 |
+
else:
|
| 234 |
+
optimizer = FP16_Optimizer(optimizer,
|
| 235 |
+
static_loss_scale=args.loss_scale,
|
| 236 |
+
verbose=False)
|
| 237 |
+
else:
|
| 238 |
+
optimizer = Adam(optimizer_grouped_parameters, args.learning_rate,
|
| 239 |
+
max_grad_norm=1.0)
|
| 240 |
+
|
| 241 |
+
#########################################################################
|
| 242 |
+
# Training !
|
| 243 |
+
##########################################################################
|
| 244 |
+
|
| 245 |
+
if args.local_rank == -1 or get_rank() == 0:
|
| 246 |
+
train_logger = open(join(log_dir, 'train_log.txt'), 'a+', buffering=1)
|
| 247 |
+
eval_logger = open(join(log_dir, 'eval_log.txt'), 'a+', buffering=1)
|
| 248 |
+
print('epoch,global_step,step,mean_loss,mean_ppl,n_token_real,'
|
| 249 |
+
'n_token_total,epoch_time', file=train_logger)
|
| 250 |
+
print('epoch,global_step,step,eval_loss,eval_ppl', file=eval_logger)
|
| 251 |
+
|
| 252 |
+
global_step = 0
|
| 253 |
+
step = 0
|
| 254 |
+
epoch = 0
|
| 255 |
+
|
| 256 |
+
if args.continue_from:
|
| 257 |
+
global_step = args.continue_from
|
| 258 |
+
step = global_step*2 - 1
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
if args.local_rank != -1:
|
| 262 |
+
n_gpu = 1
|
| 263 |
+
if args.local_rank == -1 or get_rank() == 0:
|
| 264 |
+
if args.pbar:
|
| 265 |
+
pbar = tqdm.tqdm(total=args.num_optim_steps, desc=f"training")
|
| 266 |
+
else:
|
| 267 |
+
pbar = None
|
| 268 |
+
|
| 269 |
+
while True:
|
| 270 |
+
model.train()
|
| 271 |
+
(tr_loss, tr_ppl, mean_ppl, nb_tr_examples, nb_tr_steps) = 0.0, 0.0, 0.0, 0, 0
|
| 272 |
+
n_token_real, n_token_total = 0, 0
|
| 273 |
+
train_start_time_epoch = time.time()
|
| 274 |
+
for batch in train_dataloader:
|
| 275 |
+
# activate new training mode
|
| 276 |
+
seq_len = batch[0].shape[1]
|
| 277 |
+
batch = tuple(t.to(device) for t in batch)
|
| 278 |
+
input_ids, position_ids, token_ids, label_ids, *_ = batch
|
| 279 |
+
if args.no_token_id:
|
| 280 |
+
token_ids = None
|
| 281 |
+
loss, ppl = model(input_ids, position_ids, token_ids, label_ids)
|
| 282 |
+
|
| 283 |
+
if n_gpu > 1:
|
| 284 |
+
loss = loss.mean()
|
| 285 |
+
ppl = ppl.mean()
|
| 286 |
+
loss = loss / (args.train_batch_size / input_ids.shape[0])
|
| 287 |
+
if args.fp16:
|
| 288 |
+
optimizer.backward(loss)
|
| 289 |
+
else:
|
| 290 |
+
loss.backward()
|
| 291 |
+
|
| 292 |
+
tr_loss += float(loss.item()) * (args.train_batch_size / input_ids.shape[0])
|
| 293 |
+
nb_tr_examples += input_ids.size(0)
|
| 294 |
+
nb_tr_steps += 1
|
| 295 |
+
mean_loss = tr_loss / nb_tr_steps
|
| 296 |
+
if ppl.item() < INF:
|
| 297 |
+
tr_ppl += ppl.item()
|
| 298 |
+
else:
|
| 299 |
+
tr_ppl += mean_ppl
|
| 300 |
+
mean_ppl = tr_ppl / nb_tr_steps
|
| 301 |
+
|
| 302 |
+
n_token_total += input_ids.shape[0] * input_ids.shape[1]
|
| 303 |
+
n_token_real += (input_ids != 0).sum().item()
|
| 304 |
+
|
| 305 |
+
# gradient update
|
| 306 |
+
step += 1
|
| 307 |
+
if step % args.gradient_accumulation_steps == 0:
|
| 308 |
+
set_lr(optimizer, global_step,
|
| 309 |
+
args.lr_schedule, args.learning_rate,
|
| 310 |
+
args.warmup_steps, args.warmup_proportion,
|
| 311 |
+
config.n_embd, args.num_optim_steps)
|
| 312 |
+
|
| 313 |
+
if args.local_rank != -1:
|
| 314 |
+
grads = [p.grad.data for p in model.parameters()
|
| 315 |
+
if p.requires_grad and p.grad is not None]
|
| 316 |
+
all_reduce_and_rescale_tensors(grads, float(1))
|
| 317 |
+
|
| 318 |
+
optimizer.step()
|
| 319 |
+
optimizer.zero_grad()
|
| 320 |
+
global_step += 1
|
| 321 |
+
|
| 322 |
+
# Print log info to file
|
| 323 |
+
if args.local_rank != -1:
|
| 324 |
+
mean_loss = sum(all_gather_list(mean_loss)) / get_world_size()
|
| 325 |
+
mean_ppl = sum(all_gather_list(mean_ppl)) / get_world_size()
|
| 326 |
+
n_token_real_all_proc = sum(all_gather_list(n_token_real))
|
| 327 |
+
n_token_total_all_proc = sum(all_gather_list(n_token_total))
|
| 328 |
+
else:
|
| 329 |
+
n_token_real_all_proc = n_token_real
|
| 330 |
+
n_token_total_all_proc = n_token_total
|
| 331 |
+
|
| 332 |
+
if args.local_rank == -1 or get_rank() == 0:
|
| 333 |
+
epoch_time = time.time() - train_start_time_epoch
|
| 334 |
+
if pbar is not None:
|
| 335 |
+
pbar.set_postfix_str(
|
| 336 |
+
f"tok/s: {n_token_real_all_proc//epoch_time//1000}k "
|
| 337 |
+
f"ppl: {mean_ppl:.2f} epoch: {epoch}")
|
| 338 |
+
pbar.update(1)
|
| 339 |
+
print('{},{},{},{},{},{},{},{}'.format(
|
| 340 |
+
epoch+1, global_step+1, step+1, mean_loss, mean_ppl,
|
| 341 |
+
n_token_real_all_proc, n_token_total_all_proc, epoch_time),
|
| 342 |
+
file=train_logger)
|
| 343 |
+
|
| 344 |
+
if global_step % args.valid_step == 0:
|
| 345 |
+
if args.local_rank == -1 or get_rank() == 0:
|
| 346 |
+
# only rank 0 process evaluate
|
| 347 |
+
torch.save(
|
| 348 |
+
{k: (v.cpu() if v is not None else None) # save to cpu tensors
|
| 349 |
+
for k, v in model.state_dict().items()},
|
| 350 |
+
join(output_dir,
|
| 351 |
+
f'GP2-pretrain-step-{global_step}.pkl'))
|
| 352 |
+
|
| 353 |
+
eval_loss, eval_ppl = eval_model_loss(
|
| 354 |
+
model, enc, eval_dataloader_loss, epoch, args)
|
| 355 |
+
# enable generation step evaluation for now
|
| 356 |
+
# gen_response = eval_model_generation(
|
| 357 |
+
# model, enc, eval_dataloader_gen, epoch, args)
|
| 358 |
+
'''
|
| 359 |
+
# probably use beam search only for test set
|
| 360 |
+
if False:
|
| 361 |
+
gen_response_beam = eval_model_generation(
|
| 362 |
+
model, enc, eval_dataloader_gen, epoch, args,
|
| 363 |
+
use_beam_search=True, beam_width=3)
|
| 364 |
+
'''
|
| 365 |
+
print('{},{},{},{},{}'.format(
|
| 366 |
+
epoch+1, global_step+1, step+1, eval_loss, eval_ppl),
|
| 367 |
+
file=eval_logger)
|
| 368 |
+
logger.info('current learning rate: '
|
| 369 |
+
+ str(optimizer.param_groups[0]['lr']))
|
| 370 |
+
model.train()
|
| 371 |
+
if global_step >= args.num_optim_steps:
|
| 372 |
+
break
|
| 373 |
+
|
| 374 |
+
if (step+1) % CACHE_EMPTY_STEP == 0:
|
| 375 |
+
torch.cuda.empty_cache()
|
| 376 |
+
|
| 377 |
+
if global_step >= args.num_optim_steps:
|
| 378 |
+
break
|
| 379 |
+
epoch += 1
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
if args.local_rank == -1 or get_rank() == 0:
|
| 383 |
+
if pbar is not None:
|
| 384 |
+
pbar.close()
|
| 385 |
+
train_logger.close()
|
| 386 |
+
eval_logger.close()
|
MANIFEST.in
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
include LICENSE
|
README.md
CHANGED
|
@@ -1,16 +1,425 @@
|
|
| 1 |
-
|
| 2 |
-
thumbnail: https://huggingface.co/front/thumbnails/dialogpt.png
|
| 3 |
-
tags:
|
| 4 |
-
- conversational
|
| 5 |
-
license: mit
|
| 6 |
-
---
|
| 7 |
|
| 8 |
-
##
|
| 9 |
|
| 10 |
-
|
| 11 |
-
The [human evaluation results](https://github.com/dreasysnail/Dialogpt_dev#human-evaluation) indicate that the response generated from DialoGPT is comparable to human response quality under a single-turn conversation Turing test.
|
| 12 |
-
The model is trained on 147M multi-turn dialogue from Reddit discussion thread.
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
* Multi-turn generation examples from an interactive environment:
|
| 15 |
|
| 16 |
|Role | Response |
|
|
@@ -22,33 +431,186 @@ The model is trained on 147M multi-turn dialogue from Reddit discussion thread.
|
|
| 22 |
|User |This is so difficult ! |
|
| 23 |
| Bot | You have no idea how hard it is to be a millionaire and happy . There is a reason the rich have a lot of money |
|
| 24 |
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
-
ArXiv paper: [https://arxiv.org/abs/1911.00536](https://arxiv.org/abs/1911.00536)
|
| 28 |
|
| 29 |
-
### How to use
|
| 30 |
|
| 31 |
-
Now we are ready to try out how the model works as a chatting partner!
|
| 32 |
|
| 33 |
-
```python
|
| 34 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 35 |
-
import torch
|
| 36 |
|
|
|
|
| 37 |
|
| 38 |
-
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
-
|
| 42 |
-
for step in range(5):
|
| 43 |
-
# encode the new user input, add the eos_token and return a tensor in Pytorch
|
| 44 |
-
new_user_input_ids = tokenizer.encode(input(">> User:") + tokenizer.eos_token, return_tensors='pt')
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
-
|
| 50 |
-
chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
|
| 51 |
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# A State-of-the-Art Large-scale Pretrained Response Generation Model (DialoGPT)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
+
## This project page is no longer maintained as DialoGPT is superseded by [GODEL](https://github.com/microsoft/GODEL), which outperforms DialoGPT according to the results of [this paper](https://arxiv.org/pdf/2206.11309.pdf). Unless you use DialoGPT for reproducibility reasons, we highly recommend you switch to [GODEL](https://github.com/microsoft/GODEL).
|
| 4 |
|
| 5 |
+
This repository contains the source code and trained model for a large-scale pretrained dialogue response generation model. The [human evaluation results](#human_eval) indicate that the response generated from DialoGPT is comparable to human response quality under a single-turn conversation Turing test.
|
|
|
|
|
|
|
| 6 |
|
| 7 |
+
<!--See more details on our [project page](https://www.microsoft.com/en-us/research/project/large-scale-pretraining-for-response-generation/)-->
|
| 8 |
+
|
| 9 |
+
The repository is based on [huggingface pytorch-transformer](https://github.com/huggingface/transfer-learning-conv-ai) and [OpenAI GPT-2](https://github.com/openai/gpt-2), containing data extraction script, model training code and pretrained small (117M) medium (345M) and large (762M) model checkpoint.
|
| 10 |
+
|
| 11 |
+
The model is trained on 147M multi-turn dialogue from Reddit discussion thread. The largest model can be trained in several hours on a 8 V100 machines (however this is not required), with distributed training and FP16 option.
|
| 12 |
+
|
| 13 |
+
The include script can be used to reproduce the results of DSTC-7 grounded dialogue generation challenge and a 6k multi-reference dataset created from Reddit data.
|
| 14 |
+
|
| 15 |
+
Project webpage: [https://www.microsoft.com/en-us/research/project/large-scale-pretraining-for-response-generation/](https://www.microsoft.com/en-us/research/project/large-scale-pretraining-for-response-generation/)
|
| 16 |
+
|
| 17 |
+
ArXiv paper: [https://arxiv.org/abs/1911.00536](https://arxiv.org/abs/1911.00536)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
## News ##
|
| 21 |
+
|
| 22 |
+
***(Update 07/09/2022) Changes on the files.pushshift.io/reddit server caused our data generation pipeline to break. These problems have now been fixed, and the steps explained in the Data Preparation subsection below should work again. Data is generated in about 10 hours with 8 processes (`-j 8`), and 800GB of temporary disk space is needed.***
|
| 23 |
+
|
| 24 |
+
***(Update 06/23/2021) We have released a retrieval-augmented/grounded version of DialoGPT (RetGen), please check out the [RetGen repo](https://github.com/dreasysnail/RetGen) and [RetGen paper](https://arxiv.org/abs/2105.06597)***
|
| 25 |
+
|
| 26 |
+
***(Update 05/20/2021) An awesome [video walkthrough](https://www.youtube.com/watch?v=Zo679MYoJns) on YouTube for DialoGPT by [Prakhar Mishra](http://wsl.iiitb.ac.in/prakhar-mishra/)***
|
| 27 |
+
|
| 28 |
+
***(Update 03/31/2021) A 3rd party demo by [AK391](https://github.com/AK391) using Gradio [web demo](https://gradio.app/g/AK391/DialoGPT) try it out***
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
***(Update 09/15/2020) A set of large-scale [dialog ranking models](https://github.com/golsun/DialogRPT) has been released!***
|
| 32 |
+
|
| 33 |
+
DialoGPT generation is improved by integrating with our latest dialog ranking models, [DialogRPT](https://github.com/golsun/DialogRPT)
|
| 34 |
+
|
| 35 |
+
***(Update 07/08/2020) The 6K multi-ref test set has been released!***
|
| 36 |
+
|
| 37 |
+
To generate the data, pleaser run `demo.py` and set the data option to 'full', the generated 6k multi-ref test set will be located at
|
| 38 |
+
|
| 39 |
+
`./data/test.refs.txt`
|
| 40 |
+
|
| 41 |
+
***(Update 03/10/2020) Model cards available in Huggingface Transformers!***
|
| 42 |
+
|
| 43 |
+
Please check out our model cards in huggingface Transformers repository. With several lines of code it should be pretty straighforward to play with the DialoGPT interactively.
|
| 44 |
+
|
| 45 |
+
[small model: https://huggingface.co/microsoft/DialoGPT-small](https://huggingface.co/microsoft/DialoGPT-small)
|
| 46 |
+
|
| 47 |
+
[medium model: https://huggingface.co/microsoft/DialoGPT-medium](https://huggingface.co/microsoft/DialoGPT-medium)
|
| 48 |
+
|
| 49 |
+
[large model: https://huggingface.co/microsoft/DialoGPT-large](https://huggingface.co/microsoft/DialoGPT-large)
|
| 50 |
+
|
| 51 |
+
[**(New)** Ranking model: https://huggingface.co/microsoft/DialogRPT-updown](https://huggingface.co/microsoft/DialogRPT-updown?text=I+love+NLP%21+%3C%7Cendoftext%7C%3E+Me+too%21)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
***(Update 01/06/2020) Some third-party decoding script implementations:***
|
| 55 |
+
|
| 56 |
+
- [https://github.com/polakowo/gpt2bot](https://github.com/polakowo/gpt2bot) GPT2Bot implementation based on telegram by polakowo, [ref](https://github.com/microsoft/DialoGPT/issues/3#issuecomment-573904419)
|
| 57 |
+
- [https://colab.research.google.com/drive/1PslHE4Rl4RqSa20s7HEp0ZKITBir6ezE](https://colab.research.google.com/drive/1PslHE4Rl4RqSa20s7HEp0ZKITBir6ezE) A colab interactive notebook by qywu,[ref](https://github.com/microsoft/DialoGPT/issues/3#issuecomment-551410203)
|
| 58 |
+
- [https://github.com/andreamad8/DialoGPT2-Interact](https://github.com/andreamad8/DialoGPT2-Interact) An interactive script featuring multiturn chatbot by andreamad8,[ref](https://github.com/microsoft/DialoGPT/issues/3#issuecomment-551450016)
|
| 59 |
+
- [https://github.com/LHolten/DialoGTP-MMI-decoder](https://github.com/LHolten/DialoGTP-MMI-decoder) An MMI implementation by LHolten,[ref](https://github.com/microsoft/DialoGPT/issues/3#issuecomment-558318401)
|
| 60 |
+
- [https://colab.research.google.com/drive/1-_KjlAV3J1IVDw_9KogjKDCzgFY7Jp7E](https://colab.research.google.com/drive/1-_KjlAV3J1IVDw_9KogjKDCzgFY7Jp7E) A colab interactive notebook by illuminascent@Reddit,[ref](https://www.reddit.com/r/MachineLearning/comments/dt5woy/p_dialogpt_state_of_the_art_conversational_model/?st=k530k3oo&sh=f6cd20fd)
|
| 61 |
+
- [https://colab.research.google.com/drive/15wa925dj7jvdvrz8_z3vU7btqAFQLVlG](https://colab.research.google.com/drive/15wa925dj7jvdvrz8_z3vU7btqAFQLVlG) A great tutorial of how to finetune DialoGPT to build a customized bot built by [Rostyslav Neskorozhenyi](https://www.linkedin.com/in/slanj/). [ref](https://towardsdatascience.com/make-your-own-rick-sanchez-bot-with-transformers-and-dialogpt-fine-tuning-f85e6d1f4e30)
|
| 62 |
+
- [https://gradio.app/g/AK391/DialoGPT](https://gradio.app/g/AK391/DialoGPT) A 3rd party demo by [AK391](https://github.com/AK391) using Gradio [web demo](https://gradio.app/g/AK391/DialoGPT)
|
| 63 |
+
|
| 64 |
+
<!--**This github repository will be updated soon. Please stay tuned.**-->
|
| 65 |
+
<!--## Minimal Computational Configurations-->
|
| 66 |
+
## Recommended Configuration
|
| 67 |
+
|
| 68 |
+
- Linux Ubuntu 16.04
|
| 69 |
+
- GPU with at least 12G memory
|
| 70 |
+
|
| 71 |
+
DialoGPT was developed entirely on **Ubuntu 16.04**, and -- depending on our availability -- we try to provide support if you experience difficulties running the code on the same configuration. However, we are **unable to provide support for other distributions or operating systems**. Portions of the code may run on other UNIX flavors (macOS, Windows subsystem for Linux, Cygwin, etc.), but it is recommended to use Ubuntu for the main training code.
|
| 72 |
+
|
| 73 |
+
The training code can be run on CPU, but it can be slow. We would recommend to use GPU to train and finetune all models. There is no minimal limit of the number of GPUs. However, if using distributed train for multiple GPUs configuration, the speed-up vs the number of GPUs is roughly sub-linear. To simulate the same batchsize when using less GPUs, please use a larger `gradient_accumulation_steps` in model training.
|
| 74 |
+
|
| 75 |
+
The 117M and 345M model can be loaded in a single GPU with 12G memory. The 762M model would require a single GPU that has greater than 16G memory for efficient training. The training speed on a benchmark data with 50M training instances and V100 GPUs:
|
| 76 |
+
|
| 77 |
+
| n\_gpu | epoch time (h) | token/sec |
|
| 78 |
+
|----------------------|--------|--------|
|
| 79 |
+
| 1 | 118 | 10847 |
|
| 80 |
+
| 2 | 62 | 20645 |
|
| 81 |
+
| 4 | 34 | 37647 |
|
| 82 |
+
| 8 | 18 | 71356 |
|
| 83 |
+
|
| 84 |
+
Fine-tuning from our pretrained model on a new dataset typically requires 1-2 epochs.
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
## Setup & Installation (TL;DR)
|
| 88 |
+
|
| 89 |
+
We created a demo script `demo.py` to ease the difficulty of the deployment of this system. The `demo.py` contains a pipeline of **model downloading**, data extraction, data preprocessing and model training over a dummy dataset within one commandline.
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
#### Train model with Conda Environment
|
| 94 |
+
|
| 95 |
+
Please use the below commandlines to clone, install the requirements and load the Conda environment (Note that the Nvidia CUDA 10.0 developer toolkit is required):
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
```bash
|
| 99 |
+
sudo apt-get install -y make wget gzip bzip2 xz-utils zstd sed
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
```bash
|
| 103 |
+
git clone https://github.com/microsoft/DialoGPT.git
|
| 104 |
+
cd DialoGPT
|
| 105 |
+
conda env create -f LSP-linux.yml -n LSP
|
| 106 |
+
conda activate LSP
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
If you run this on an architecture other than Linux, please use `LSP-generic.yml` instead of `LSP-linux.yml` but please note that the generic one is not tested in all platform, so the stablity can not be gauranteed.
|
| 110 |
+
To use fp16 training, please install apex by using commands below
|
| 111 |
+
|
| 112 |
+
```bash
|
| 113 |
+
conda activate LSP
|
| 114 |
+
git clone https://github.com/NVIDIA/apex
|
| 115 |
+
cd apex
|
| 116 |
+
git reset --hard 3d01e4a0a188cc8df54bc6e44cf5eb40ff6b4cc5
|
| 117 |
+
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
|
| 118 |
+
python3.6 demo.py
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
#### Train model with Docker environment
|
| 122 |
+
To start, first install the docker and Nvidia-docker from their official repos.
|
| 123 |
+
The image environment for running the code can be loaded as below:
|
| 124 |
+
|
| 125 |
+
*Nvidia-docker v2.**
|
| 126 |
+
|
| 127 |
+
```bash
|
| 128 |
+
$ docker run --gpus all --ipc=host --rm -it -v $PWD:/workspace --network=host icaruszyz/large-scale-training:dialogpt bash
|
| 129 |
+
```
|
| 130 |
+
*Nvidia-docker v1.**
|
| 131 |
+
|
| 132 |
+
```bash
|
| 133 |
+
$ nvidia-docker --rm -it -v $PWD:/workspace --network=host icaruszyz/large-scale-training:dialogpt bash
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
Inside the docker container, run
|
| 137 |
+
|
| 138 |
+
```bash
|
| 139 |
+
python demo.py
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
## Pipeline details
|
| 145 |
+
|
| 146 |
+
This section explains all components in the `demo.py`.
|
| 147 |
+
|
| 148 |
+
#### Data loading
|
| 149 |
+
Before running `demo.py`, you can set *DATA_FOLDER* (default value `./models`) in `demo.py` as the place you want to download all the data and pretrained/fine-tuned models. Then simply run
|
| 150 |
+
```bash
|
| 151 |
+
python demo.py
|
| 152 |
+
```
|
| 153 |
+
to
|
| 154 |
+
|
| 155 |
+
* automatically download models and data,
|
| 156 |
+
* prepare raw data into db that is ready to use for the program,
|
| 157 |
+
* generate a training scripts.
|
| 158 |
+
|
| 159 |
+
Note that by default the `demo.py` will use a dummy data, please specify the Reddit training data by using option `--data`. Three options are available:`dummy`,`small` and `full`.
|
| 160 |
+
|
| 161 |
+
```bash
|
| 162 |
+
python demo.py --data small
|
| 163 |
+
python demo.py --data full
|
| 164 |
+
```
|
| 165 |
+
|
| 166 |
+
The small Reddit data is around 140MB and the full Reddit data is more than 27GB. You can prepare a cup of coffee when processing with the full Reddit data because **it takes a long time**!
|
| 167 |
+
|
| 168 |
+
To generate the 6k multi-ref test set data, pleaser run `demo.py` and set the data option to 'full', the generation will be located at
|
| 169 |
+
|
| 170 |
+
`./data/test.refs.txt`
|
| 171 |
+
|
| 172 |
+
#### Pretrained model
|
| 173 |
+
|
| 174 |
+
The pretrained and fine-tuned models are available on azure blobstorage.
|
| 175 |
+
Please run/see `demo.py` for more details about how to download/use those models. Or you could download directly by using the links in `demo_utils.py`.
|
| 176 |
+
|
| 177 |
+
#### Preparing data
|
| 178 |
+
First, use the `prepare4db.sh` to convert a tsv data file into the correct format that the following script can recognize.
|
| 179 |
+
The trainig data need to be then processed into a database file with below commandline:
|
| 180 |
+
|
| 181 |
+
```bash
|
| 182 |
+
python prepro.py --corpus $DATA_PATH
|
| 183 |
+
```
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
#### Using the training script
|
| 188 |
+
|
| 189 |
+
The training script can be used in single GPU or multiple GPU settings (distributed training across multiple GPUs within a single node):
|
| 190 |
+
|
| 191 |
+
```bash
|
| 192 |
+
python ./LSP_train.py # Single GPU training
|
| 193 |
+
python -m torch.distributed.launch --nproc_per_node=8 ./LSP_train.py # Training on 8 GPUs
|
| 194 |
+
```
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
The training script accept several arguments to tweak the training:
|
| 198 |
+
|
| 199 |
+
Argument | Type | Default value | Description
|
| 200 |
+
---------|------|---------------|------------
|
| 201 |
+
max\_seq\_length | `int` | `128` | Maximum number of tokens for each training instance.
|
| 202 |
+
train\_input\_file | `str` | `""` | Path of the training dataset in a .db format
|
| 203 |
+
eval\_input\_file | `str` | `""` | Path of the validation set in a tsv format
|
| 204 |
+
continue_from | `int` | `0` | Resuming the training after a specified number of steps
|
| 205 |
+
fp16 | `boolean` | `True` | Whether to use 16-bits floating point for model training.
|
| 206 |
+
train\_batch\_size | `int` | `4` | Batch size for training
|
| 207 |
+
valid\_batch\_size | `int` | `4` | Batch size for validation
|
| 208 |
+
gradient\_accumulation\_steps | `int` | `2` | Accumulate gradients on several steps
|
| 209 |
+
learning\_rate | `float` | `1e-5` | Learning rate
|
| 210 |
+
lr\_schedule | `str` | `noam` | Learning rate schedule can be chosen from [`noam`, `noamwd`, `BERT`, `None`]
|
| 211 |
+
num\_optim\_steps | `int` | `1000000` | Number of training optimization steps
|
| 212 |
+
no_token_id | `boolean` | `True` | If set True, using all-zeros token-type embedding.
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
During the training, two log files will be updated. The `train_log.txt` and `eval_log.txt` contains the model loss, perplexity and training speed (tokens/sec) statistics for the training and dev set.
|
| 216 |
+
|
| 217 |
+
The log file and saved model checkpoint can be found in `./models/output_model`
|
| 218 |
+
|
| 219 |
+
#### Model decoding
|
| 220 |
+
We note that even with properly filtered Reddit dataset, sometimes our model can still generate moderately toxic/inappropriate responses. Due to this reason, we are unable to provide the decoding script at this time (The live demo and decoding script access is upon invitation only now ).
|
| 221 |
+
We are currently still working on a controlled decoding method to prevent this system from toxic generation. Please stay tuned.
|
| 222 |
+
|
| 223 |
+
**See issues [#3](https://github.com/microsoft/DialoGPT/issues/3) and [Reddit discussions](https://www.reddit.com/r/MachineLearning/comments/dt5woy/p_dialogpt_state_of_the_art_conversational_model/) for some discussions on third-party decoding methods.**
|
| 224 |
+
|
| 225 |
+
See below for some third-party decoding methods:
|
| 226 |
+
- [https://github.com/polakowo/gpt2bot](https://github.com/polakowo/gpt2bot) GPT2Bot implementation based on telegram by polakowo, [ref](https://github.com/microsoft/DialoGPT/issues/3#issuecomment-573904419)
|
| 227 |
+
- [https://colab.research.google.com/drive/1PslHE4Rl4RqSa20s7HEp0ZKITBir6ezE](https://colab.research.google.com/drive/1PslHE4Rl4RqSa20s7HEp0ZKITBir6ezE) A colab interactive notebook by qywu,[ref](https://github.com/microsoft/DialoGPT/issues/3#issuecomment-551410203)
|
| 228 |
+
- [https://github.com/andreamad8/DialoGPT2-Interact](https://github.com/andreamad8/DialoGPT2-Interact) An interactive script featuring multiturn chatbot by andreamad8,[ref](https://github.com/microsoft/DialoGPT/issues/3#issuecomment-551450016)
|
| 229 |
+
- [https://github.com/LHolten/DialoGTP-MMI-decoder](https://github.com/LHolten/DialoGTP-MMI-decoder) An MMI implementation by LHolten,[ref](https://github.com/microsoft/DialoGPT/issues/3#issuecomment-558318401)
|
| 230 |
+
- [https://colab.research.google.com/drive/1-_KjlAV3J1IVDw_9KogjKDCzgFY7Jp7E](https://colab.research.google.com/drive/1-_KjlAV3J1IVDw_9KogjKDCzgFY7Jp7E) A colab interactive notebook by illuminascent@Reddit,[ref](https://www.reddit.com/r/MachineLearning/comments/dt5woy/p_dialogpt_state_of_the_art_conversational_model/?st=k530k3oo&sh=f6cd20fd)
|
| 231 |
+
- [https://gradio.app/g/AK391/DialoGPT](https://gradio.app/g/AK391/DialoGPT) A 3rd party demo by [AK391](https://github.com/AK391) using Gradio [web demo](https://gradio.app/g/AK391/DialoGPT)
|
| 232 |
+
|
| 233 |
+
## Models
|
| 234 |
+
|
| 235 |
+
We release 6 fine-tuned models which can be further fine-tuned on low-resource user-customized dataset. The total parameters in these models range from 117M to 762M, in accord with OpenAI GPT-2 model sizes.
|
| 236 |
+
|
| 237 |
+
| Model | Fine-tuned from GPT-2| Trained from scratch
|
| 238 |
+
|----------------------|--------|--------|
|
| 239 |
+
| DialoGPT 762M model| [\[link\]](https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/large_ft.pkl) [\[huggingface model card\]](https://huggingface.co/microsoft/DialoGPT-large) | [\[link\]](https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/large_fs.pkl) |
|
| 240 |
+
| DialoGPT 345M model| [\[link\]](https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/medium_ft.pkl) [\[huggingface model card\]](https://huggingface.co/microsoft/DialoGPT-medium) | [\[link\]](https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/medium_fs.pkl) |
|
| 241 |
+
| DialoGPT 117M model| [\[link\]](https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/small_ft.pkl) [\[huggingface model card\]](https://huggingface.co/microsoft/DialoGPT-small)| [\[link\]](https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/small_fs.pkl) |
|
| 242 |
+
| DialoGPT 345M model (reverse, for MMI)| [link](https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/small_reverse.pkl) | -|
|
| 243 |
+
| [DialogRPT](https://github.com/golsun/DialogRPT) (**new** ranking models) | [link](https://github.com/golsun/DialogRPT) | -|
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
The model files can be loaded exactly as the GPT-2 model checkpoints from Huggingface's [Transformers](https://github.com/huggingface/transformers). You can find the corresponding configuration files (`merges.txt`, `config.json`, `vocab.json`) in DialoGPT's repo in `./configs/*`.
|
| 247 |
+
|
| 248 |
+
The reverse model is predicting the source from the target. This model is used for MMI reranking.
|
| 249 |
+
|
| 250 |
+
The [DialogRPT](https://github.com/golsun/DialogRPT) models our recently proposed ranking models used to predict the human feedback (upvotes, replies) of the responses. These models can be used to improve the DialoGPT generation quality (see our [EMNLP paper](https://arxiv.org/abs/2009.06978) for details).
|
| 251 |
+
|
| 252 |
+
## Retraining full models
|
| 253 |
+
|
| 254 |
+
### Data Preparation
|
| 255 |
+
|
| 256 |
+
The first step to retrain the full models is to generate the aforementioned 27GB Reddit dataset. This involves downloading full Reddit submission and comments dumps from [https://files.pushshift.io/reddit](https://files.pushshift.io/reddit) and creating intermediate files, which overall require **700GB of local disk space**. Downloading and processing the full data requires about 1-2 days, depending on your (CPU) compute capabilties (e.g., ~24 hours with 8 cores on a recent computer). Assuming you ran the above setup and installation steps (conda activate LSP, etc.), you can create the full dataset by running either:
|
| 257 |
+
|
| 258 |
+
```
|
| 259 |
+
python demo.py --data full
|
| 260 |
+
```
|
| 261 |
+
or
|
| 262 |
+
```
|
| 263 |
+
cd reddit_extractor; SIZE=full make -j 8; cd ..
|
| 264 |
+
```
|
| 265 |
+
|
| 266 |
+
The former command calls the latter, so the two methods are equivalent. We recommend the former, as the latter is mostly useful if you run into any problem or want to customize any arguments (e.g., the `make` command lets you build only a subset of the data). Note that the downloading phase can be error prone, for example based on your geolocation (firewall, etc.). If the above commands fail to generate `data/train.tsv`, or if that file is not anywhere close to 27GB, it means something went wrong. In that case, you may want to inspect `reddit_extractor/wget-log` and `reddit_extractor/logs/*.log` for any obvious error (e.g., wget unable to download from pushshift.io). If error messages don't make sense to you, feel free to contact us. If so, please be sure to include any error messages gathered from these log files.
|
| 267 |
+
|
| 268 |
+
Training data statistics: the generated training tsv file should be roughly 26.8 GB uncompressed, with 146.8M training instances, 3.87B source tokens, and 2.14B target tokens (including utterance-level 0/1 weights). The resulting train.tsv file should contain 146,846,215 lines.
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
### Training
|
| 272 |
+
|
| 273 |
+
We recommand generating the above data using the `demo.py --data full`, as it (1) generates the data, (2) converts it into DB format, and (3) trains a model using `python LSP_train.py`. Please directly edit `demo.py` if you want to customize any of the hyperparameters.
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
## Evaluations
|
| 277 |
+
|
| 278 |
+
#### DSTC-7 challenge
|
| 279 |
+
|
| 280 |
+
Our model achieved the state-of-the-art results in [DSTC-7 Challenge response generation task](https://github.com/mgalley/DSTC7-End-to-End-Conversation-Modeling).
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
| Experiment | NIST2 | NIST4 | BLEU2 | BLEU4 | METEOR | ENT-4 | DIST-1 | DIST-2 | Avg. Len |
|
| 284 |
+
|--------------------|-------|-------|--------|-------|--------|----------|------------|------------|---------|
|
| 285 |
+
| Human response | 2.62 | 2.65 | 12.35% | 3.13% | 8.31% | 10.45 | 16.66% | 67.01% | 18.8 |
|
| 286 |
+
| DSTC-7 Winner | 2.51 | 2.52 | 14.35% | 1.83% | 8.07% | 9.03 | 10.89% | 32.49% | 15.1 |
|
| 287 |
+
| DialoGPT 345M | 2.80 | 2.82 | 14.16% | 2.31% | 8.51% | **10.08** | 9.13% | 39.73% | 16.9 |
|
| 288 |
+
| DialoGPT 345M (BS) | **2.92** | **2.97** | **19.18%** | **6.05%** | **9.29%** | 9.57 | **15.73%** | **51.03%** | 14.2 |
|
| 289 |
+
|
| 290 |
+
where ENT represents the [Entropy score](https://arxiv.org/abs/1809.05972), and DIST represents the [Distinct score](https://arxiv.org/pdf/1510.03055.pdf). For all metrics except the average length, larger are better.
|
| 291 |
+
|
| 292 |
+
<!--| Experiment | NIST1 | NIST2 | NIST3 | NIST4 | BLEU1 | BLEU2 | BLEU3 | BLEU4 | METEOR | ENT-1 | ENT-2 | ENT-3 | ENT-4 | DIST-1 | DIST-2 | Len |
|
| 293 |
+
|----------------------|--------|--------|--------|--------|--------|--------|--------|--------|--------|----------|----------|----------|----------|------------|------------|---------|
|
| 294 |
+
| Human | 2.4237 | 2.6244 | 2.6472 | 2.65 | 0.3408 | 0.1235 | 0.0572 | 0.0313 | 0.0831 | 6.5893 | 9.7423 | 10.4101 | 10.4450 | 0.1666 | 0.6701 | 18.7568 |
|
| 295 |
+
| DSTC-7 Winner | 2.3408 | 2.5102 | 2.522 | 2.523 | 0.4122 | 0.1435 | 0.0501 | 0.0183 | 0.0807 | 5.3832 | 7.6065 | 8.5304 | 9.0298 | 0.1089 | 0.3249 | 15.1327 |
|
| 296 |
+
| DialoGPT | 2.5863 | 2.804 | 2.823 | 2.8246 | 0.3927 | 0.1416 | 0.0555 | 0.0231 | 0.0851 | 5.5791 | 8.5109 | 9.6872 | 10.0765 | 0.0913 | 0.3973 | 16.9484 |
|
| 297 |
+
| DialoGPT(beam search) | **2.5943**| **2.9163** | **2.9624** | **2.9681**| **0.4238** | **0.1918** | **0.1027** | **0.0605** | **0.0929** | **6.0815** | **8.7379** | 9.4037 | 9.5697 | 0.1573 | 0.5103 | 14.1603 |-->
|
| 298 |
+
|
| 299 |
+
Note that the superior automatic evaluation comparing to human responses does not necessary imply that our model achieves human parity. Please check out our paper for more detailed analysis.
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
To fine-tune the `345M` DialoGPT model on the DSTC-7 challenge data on a server with 8 V100 GPUs, please run the following commandline (The DSTC data can be found at [DSTC-7 repo](https://github.com/mgalley/DSTC7-End-to-End-Conversation-Modeling)):
|
| 303 |
+
|
| 304 |
+
```bash
|
| 305 |
+
python3 -m torch.distributed.launch --nproc_per_node=8 train_LSP.py --init_checkpoint ./models/medium/medium_ft.pkl --train_input_file ./data/DSTC_train.db --eval_input_file ./data/DSTC_valid.tsv --model_name_or_path ./model/medium/ --learning_rate 1e-4 --train_batch_size 64 --eval_batch_size 64 --no_token_id
|
| 306 |
+
```
|
| 307 |
+
|
| 308 |
+
The trained model can be found at [DSTC medium model](https://acvrpublicycchen.blob.core.windows.net/dialogpt/DSTC/medium_ft.pkl)
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
#### Evaluation
|
| 312 |
+
|
| 313 |
+
1. Please **downloads** the following 3rd-party packages and save into the empty folder `3rdparty`:
|
| 314 |
+
* [**mteval-v14c.pl**](https://goo.gl/YUFajQ) to compute [NIST](http://www.mt-archive.info/HLT-2002-Doddington.pdf). You may need to install the following [perl](https://www.perl.org/get.html) modules (e.g. by `cpan install`): XML:Twig, Sort:Naturally and String:Util.
|
| 315 |
+
* [**meteor-1.5**](http://www.cs.cmu.edu/~alavie/METEOR/download/meteor-1.5.tar.gz) to compute [METEOR](http://www.cs.cmu.edu/~alavie/METEOR/index.html). It requires [Java](https://www.java.com/en/download/help/download_options.xml).
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
2. Please follow the [DSTC-7 official repo](https://github.com/mgalley/DSTC7-End-to-End-Conversation-Modeling/tree/master/data_extraction) to extract the data, and put `data-official-test/test.refs.txt` into `./dstc/data/` folder.
|
| 319 |
+
|
| 320 |
+
3. Run the extraction script below to produce the human response hypothesis file `human.resp.txt`:
|
| 321 |
+
|
| 322 |
+
```bash
|
| 323 |
+
python extract_human.py
|
| 324 |
+
```
|
| 325 |
+
|
| 326 |
+
4. Finally, to reproduce the results of human hypothesis on DSTC dataset, please run following commands under the repo folder:
|
| 327 |
+
|
| 328 |
+
```bash
|
| 329 |
+
python batch_eval.py
|
| 330 |
+
```
|
| 331 |
+
|
| 332 |
+
The evaluation results will be generated in the folder `./dstc/eval/`
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
## 6K multi-ref dataset result
|
| 336 |
+
|
| 337 |
+
### Automatic evaluation
|
| 338 |
+
|
| 339 |
+
We test on 6K multi-ref dataset from Reddit. The results are summarized in below
|
| 340 |
+
|
| 341 |
+
| Experiment | NIST2 | NIST4 | BLEU2 | BLEU4 | METEOR | ENT-4 | DIST-1 | DIST-2 | Avg. Len |
|
| 342 |
+
|--------------------|-------|-------|--------|-------|--------|----------|------------|------------|---------|
|
| 343 |
+
| Human response | 3.41 | 4.25 | 17.90% | 7.48% | 10.64% | 11 | 14.50% | 63.00% | 13.1 |
|
| 344 |
+
| DialoGPT 117M | 2.39 | 2.41 | 10.54% | 1.55% | 7.53% | 10.78 | 8.60% | 39.90% | 12.8 |
|
| 345 |
+
| DialoGPT 345M | 3 | 3.06 | 16.96% | 4.56% | 9.81% | 9.13 | 6.80% | 26.30% | 12.2 |
|
| 346 |
+
| DialoGPT 762M | 2.84 | 2.9 | 18.66% | 5.25% | 9.66% | 9.72 | 7.76% | 29.93% | 11.2 |
|
| 347 |
+
| DialoGPT 345M (BS) | **3.4** | **3.5** | **21.76%** | **7.92%** | 10.74% | 10.48 | **12.38%** | **48.74%** | 11.3 |
|
| 348 |
+
| DialoGPT 345M (w/MMI)| 3.28 | 3.33 | 15.68% | 3.94% | **11.23%** | **11.25** | 9.39% | 45.55% | 17.2 |
|
| 349 |
+
|
| 350 |
+
### <a name="human_eval"></a>Human evaluation
|
| 351 |
+
|
| 352 |
+
We further conduct human evaluations (6K examples for each methods, each example is evaluated by 3 human judges). The results show a strong evidence that our generation quality is towards approaching the quality of real human responses, under this non-interactive Turing test:
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
*Relevance*: A and B, which one is more relevant to the source prompt.
|
| 356 |
+
|
| 357 |
+
| System A | A Wins (%) | Ties (%) | B Wins (%) | System B|
|
| 358 |
+
|--------------------|-------|-------|--------|-------|
|
| 359 |
+
|DialoGPT 345M|2671 (45%) | 513 (9%) | 2816 (47%)| Human responses|
|
| 360 |
+
|DialoGPT 345M| 3281 (72%)| 394 (9%) | 882 (19%)| [PersonalityChat](https://docs.microsoft.com/en-us/azure/cognitive-services/project-personality-chat/overview)|
|
| 361 |
+
|DialoGPT 345M w/ MMI| **2871** (48%)| 522 (9%) | 2607 (43%)| Human responses|
|
| 362 |
+
|
| 363 |
+
*Informativeness*: A and B, which one is more contentful and informative.
|
| 364 |
+
|
| 365 |
+
| System A | A Wins (%) | Ties (%) | B Wins (%) | System B|
|
| 366 |
+
|--------------------|-------|-------|--------|-------|
|
| 367 |
+
|DialoGPT 345M| 2722 (45%) | 234 (4%) | 3044 (51%)| Human responses|
|
| 368 |
+
|DialoGPT 345M|3490 (77%) | 206 (5%) | 861 (19%)| [PersonalityChat](https://docs.microsoft.com/en-us/azure/cognitive-services/project-personality-chat/overview)|
|
| 369 |
+
|DialoGPT 345M w/ MMI| **3011** (50%)| 234 (4%) | 2755 (46%)| Human responses|
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
*Human-Like*: A and B, which one do you think is more likely to be generated by Human.
|
| 373 |
+
|
| 374 |
+
| System A | A Wins (%) | Ties (%) | B Wins (%) | System B|
|
| 375 |
+
|--------------------|-------|-------|--------|-------|
|
| 376 |
+
|DialoGPT 345M|2716 (45%) | 263 (4%) | 3021 (50%)| Human responses|
|
| 377 |
+
|DialoGPT 345M|3462 (76%) | 196 (4%) | 899 (20%)| [PersonalityChat](https://docs.microsoft.com/en-us/azure/cognitive-services/project-personality-chat/overview)|
|
| 378 |
+
|DialoGPT 345M w/ MMI| **2978** (50%)| 241 (4%) | 2781 (46%)| Human responses|
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
Please see full details in our [arxiv paper](https://arxiv.org/abs/1911.00536).
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
<!--Relevance
|
| 387 |
+
System Wins (%) Ties (%) Losses (%)
|
| 388 |
+
2 vs 1 2671 (0.45) 513 (0.09) 2816 (0.47)
|
| 389 |
+
2 vs 3 3281 (0.72) 394 (0.09) 882 (0.19)
|
| 390 |
+
2 vs 4 2379 (0.40) 527 (0.09) 3094 (0.52)
|
| 391 |
+
2 vs 5 3019 (0.50) 581 (0.10) 2400 (0.40)
|
| 392 |
+
2 vs 6 2726 (0.45) 576 (0.10) 2698 (0.45)
|
| 393 |
+
|
| 394 |
+
Informativeness
|
| 395 |
+
System Wins (%) Ties (%) Losses (%)
|
| 396 |
+
2 vs 1 2722 (0.45) 234 (0.04) 3044 (0.51)
|
| 397 |
+
2 vs 3 3490 (0.77) 206 (0.05) 861 (0.19)
|
| 398 |
+
2 vs 4 2474 (0.41) 257 (0.04) 3269 (0.54)
|
| 399 |
+
2 vs 5 3230 (0.54) 362 (0.06) 2408 (0.40)
|
| 400 |
+
2 vs 6 2856 (0.48) 303 (0.05) 2841 (0.47)
|
| 401 |
+
|
| 402 |
+
Human-Like
|
| 403 |
+
System Wins (%) Ties (%) Losses (%)
|
| 404 |
+
2 vs 1 2716 (0.45) 263 (0.04) 3021 (0.50)
|
| 405 |
+
2 vs 3 3462 (0.76) 196 (0.04) 899 (0.20)
|
| 406 |
+
2 vs 4 2478 (0.41) 289 (0.05) 3233 (0.54)
|
| 407 |
+
2 vs 5 3233 (0.54) 340 (0.06) 2427 (0.40)
|
| 408 |
+
2 vs 6 2847 (0.47) 321 (0.05) 2832 (0.47)
|
| 409 |
+
-->
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
<!--| Experiment | NIST1 | NIST2 | NIST3 | NIST4 | BLEU1 | BLEU2 | BLEU3 | BLEU4 | METEOR | ENT-4 | DIST-1 | DIST-2 |
|
| 413 |
+
|------------------------------|-------|-------|-------|-------|--------|--------|--------|-------|--------|----------|------------|------------|
|
| 414 |
+
| Human response | 2.99 | 3.41 | 3.83 | 4.25 | 39.61% | 17.90% | 10.71% | 7.48% | 10.64% | 11 | 14.50% | 63.00% |
|
| 415 |
+
| DialoGPT 117M | 2.25 | 2.39 | 2.41 | 2.41 | 35.43% | 10.54% | 3.85% | 1.55% | 7.53% | 10.78 | 8.60% | 39.90% |
|
| 416 |
+
| DialoGPT 345M |**2.67** | **3** | **3.05** | **3.06** | 40.97% | 16.96% | 8.31% | 4.56% | **9.81%** | 9.13 | 6.80% | 26.30% |
|
| 417 |
+
| DialoGPT 762M | 2.51 | 2.84 | 2.89 | 2.9 | **44.45%** | **18.66%** | **9.32%** | **5.25%** | 9.66% | **9.72** | **7.76%** | **29.93%** |-->
|
| 418 |
+
|
| 419 |
+
<!--| DialoGPT 762M (beam search)| 2.51 | 2.9 | 2.97 | 2.98 | 44.84% | 21.08% | 12.01% | 7.57% | 10.11% | 10.06 | 11.62% | 44.07% |-->
|
| 420 |
+
<!--## ConvAI challenge -->
|
| 421 |
+
<!--| DialoGPT 345M (beam search) | 2.94 | 3.4 | 3.49 | 3.5 | 45.27% | 21.76% | 12.51% | 7.92% | 10.74% | 10.48 | 12.38% | 48.74% |-->
|
| 422 |
+
## Generated Examples
|
| 423 |
* Multi-turn generation examples from an interactive environment:
|
| 424 |
|
| 425 |
|Role | Response |
|
|
|
|
| 431 |
|User |This is so difficult ! |
|
| 432 |
| Bot | You have no idea how hard it is to be a millionaire and happy . There is a reason the rich have a lot of money |
|
| 433 |
|
| 434 |
+
* Multi-turn self-playing dialogue with user prompt
|
| 435 |
+
|
| 436 |
+
|Role | Response |
|
| 437 |
+
|---------|--------|
|
| 438 |
+
|User | what is the meaning of life ?|
|
| 439 |
+
| Bot | The meaning is to be with your family and friends .|
|
| 440 |
+
|Bot | I’m going to guess : It means that your parents and friends have loved you . That’s the meaning , I believe . |
|
| 441 |
+
| Bot | I think the meaning is to make sure you’re alive when you die . So you’ll be with your family , friends , and yourself .|
|
| 442 |
+
|Bot |It’s not , but thanks for trying to help . |
|
| 443 |
|
|
|
|
| 444 |
|
|
|
|
| 445 |
|
|
|
|
| 446 |
|
|
|
|
|
|
|
|
|
|
| 447 |
|
| 448 |
+
* Commonsense question answering examples
|
| 449 |
|
| 450 |
+
|User | Bot |
|
| 451 |
+
|---------|--------|
|
| 452 |
+
|who is the first president of the United States | George Washington |
|
| 453 |
+
|what is the boiling point of water? | I think it's about 212 F. |
|
| 454 |
+
|which one is bigger, sun or moon? | The sun . |
|
| 455 |
+
|which animal has black and white stripes? | A black and white striped cat . |
|
| 456 |
+
|
| 457 |
+
* Handling context that does not appear in training examples (note that training examples are sampled from 2005 to 2017)
|
| 458 |
+
|
| 459 |
+
|User | Bot |
|
| 460 |
+
|---------|--------|
|
| 461 |
+
|The trading war between China and US is still happening . | I'm sure the Chinese government and the U . S . will be fine with that as long as they keep their arms on the table and the money flowing . |
|
| 462 |
+
|who won the world cup in 2018 ? | Germany , I believe . |
|
| 463 |
+
|Nvidia's Titan RTX is really good . | I have this card and i love it . No complaints . |
|
| 464 |
+
|Can Porsche beat Tesla with its new Taycan EV ? | I don't think so . The Tesla is more powerful than the porsche .
|
| 465 |
+
|
| 466 |
+
* Our bot's response to philosophy questions
|
| 467 |
+
|
| 468 |
+
|User | Bot |
|
| 469 |
+
|---------|--------|
|
| 470 |
+
|What is the meaning of a good life ? | I think it's that we're all connected to our past lives and the meaning of life is to live the way you want to and how you want to . |
|
| 471 |
+
|Which is more important -- being successful or being happy ? | Being successful is just endless . Being happy is the only thing that matters . |
|
| 472 |
+
|Is peace the only way to stop war ? | Not if the world is not willing to work with you . Peace is the ultimate goal , but peace requires more than that . |
|
| 473 |
+
| How much freedom should people have ? | The right amount of freedom is the freedom to do as you please , as long as you don't hurt people or property .|
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
* Finally
|
| 477 |
+
|
| 478 |
+
|User | Bot |
|
| 479 |
+
|---------|--------|
|
| 480 |
+
|What do you think of the DialoGPT repo ? | It's pretty good , but it's very easy to find a bug . |
|
| 481 |
+
|
| 482 |
+
Please start a issue if you spot any :)
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
## Try our system
|
| 487 |
+
The live demo and decoding script access is upon invitation only now. Please stayed tuned for the full release.
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
<!--The live demo and decoding script access is upon approval request. Please apply [here](https://docs.google.com/spreadsheets/d/1epTNfaqva1isVO_o9pbyhVLsnzDn58dGkcLB0OUVcqs/edit?usp=sharing)-->
|
| 491 |
+
|
| 492 |
+
<!--This model should give a Hits@1 over 79, perplexity of 20.5 and F1 of 16.5 using the convai2 evaluation script (see below).
|
| 493 |
+
|
| 494 |
+
These numbers are slightly lower than the number we obtained in the ConvAI2 competition. Here is what you can tweak to reach the same results:
|
| 495 |
+
|
| 496 |
+
- in the ConvAI2 competition we also used tweaked position emebddings so that the history of the dialog always start at with the same embeddings. This is easy to add with pytorch-pretrained-bert and should improve the hits@1 metric.
|
| 497 |
+
- in the ConvAI2 competition we used a beam search decoder. While the results are better in term of f1 metric, our feeling is that the human experience is les compelling with beam search versus the nucleus sampling detector which is provided in the present repository.-->
|
| 498 |
+
|
| 499 |
+
<!--## Using the interaction script
|
| 500 |
+
|
| 501 |
+
The training script saves all the experiments and checkpoints in a sub-folder named with the timestamp of the experiment in the `./runs` folder of the repository base folder.
|
| 502 |
+
|
| 503 |
+
You can then use the interactive script to interact with the model simply by pointing to this folder.
|
| 504 |
+
|
| 505 |
+
Here is an example command line to run the interactive script:
|
| 506 |
+
|
| 507 |
+
```bash
|
| 508 |
+
python ./interact.py --model_checkpoint ./data/Apr17_13-31-38_thunder/ # run the interactive script with a training checkpoint
|
| 509 |
+
python ./interact.py # run the interactive script with the finetuned model on our S3
|
| 510 |
+
```
|
| 511 |
+
|
| 512 |
+
The fine-tuned model will gives FINAL Hits@1: 0.715
|
| 513 |
|
| 514 |
+
The interactive script accept a few arguments to tweak the decoding algorithm:
|
|
|
|
|
|
|
|
|
|
| 515 |
|
| 516 |
+
Argument | Type | Default value | Description
|
| 517 |
+
---------|------|---------------|------------
|
| 518 |
+
dataset_path | `str` | `""` | Path or url of the dataset. If empty download from S3.
|
| 519 |
+
dataset_cache | `str` | `'./dataset_cache.bin'` | Path or url of the dataset cache
|
| 520 |
+
model | `str` | `"openai-gpt"` | Path, url or short name of the model
|
| 521 |
+
max_history | `int` | `2` | Number of previous utterances to keep in history
|
| 522 |
+
device | `str` | `cuda` if `torch.cuda.is_available()` else `cpu` | Device (cuda or cpu)
|
| 523 |
+
no_sample | action `store_true` | Set to use greedy decoding instead of sampling
|
| 524 |
+
max_length | `int` | `20` | Maximum length of the output utterances
|
| 525 |
+
min_length | `int` | `1` | Minimum length of the output utterances
|
| 526 |
+
seed | `int` | `42` | Seed
|
| 527 |
+
temperature | `int` | `0.7` | Sampling softmax temperature
|
| 528 |
+
top_k | `int` | `0` | Filter top-k tokens before sampling (`<=0`: no filtering)
|
| 529 |
+
top_p | `float` | `0.9` | Nucleus filtering (top-p) before sampling (`<=0.0`: no filtering)
|
| 530 |
|
| 531 |
+
## Running ConvAI2 evaluation scripts
|
|
|
|
| 532 |
|
| 533 |
+
To run the evaluation scripts of the ConvAI2 challenge, you first need to install `ParlAI` in the repo base folder like this:
|
| 534 |
+
|
| 535 |
+
```bash
|
| 536 |
+
git clone https://github.com/facebookresearch/ParlAI.git
|
| 537 |
+
cd ParlAI
|
| 538 |
+
python setup.py develop
|
| 539 |
+
```
|
| 540 |
+
|
| 541 |
+
You can then run the evaluation script from `ParlAI` base folder:
|
| 542 |
+
|
| 543 |
+
```bash
|
| 544 |
+
cd ParlAI
|
| 545 |
+
python ../convai_evaluation.py --eval_type hits@1 # to download and evaluate our fine-tuned model on hits@1 metric
|
| 546 |
+
python ../convai_evaluation.py --eval_type hits@1 --model_checkpoint ./data/Apr17_13-31-38_thunder/ # to evaluate a training checkpoint on hits@1 metric
|
| 547 |
```
|
| 548 |
+
|
| 549 |
+
The evaluation script accept a few arguments to select the evaluation metric and tweak the decoding algorithm:
|
| 550 |
+
|
| 551 |
+
Argument | Type | Default value | Description
|
| 552 |
+
---------|------|---------------|------------
|
| 553 |
+
eval_type | `str` | `"hits@1"` | Evaluate the model on `hits@1`, `ppl` or `f1` metric on the ConvAI2 validation dataset
|
| 554 |
+
model | `str` | `"openai-gpt"` | Path, url or short name of the model
|
| 555 |
+
max_history | `int` | `2` | Number of previous utterances to keep in history
|
| 556 |
+
device | `str` | `cuda` if `torch.cuda.is_available()` else `cpu` | Device (cuda or cpu)
|
| 557 |
+
no_sample | action `store_true` | Set to use greedy decoding instead of sampling
|
| 558 |
+
max_length | `int` | `20` | Maximum length of the output utterances
|
| 559 |
+
min_length | `int` | `1` | Minimum length of the output utterances
|
| 560 |
+
seed | `int` | `42` | Seed
|
| 561 |
+
temperature | `int` | `0.7` | Sampling softmax temperature
|
| 562 |
+
top_k | `int` | `0` | Filter top-k tokens before sampling (`<=0`: no filtering)
|
| 563 |
+
top_p | `float` | `0.9` | Nucleus filtering (top-p) before sampling (`<=0.0`: no filtering)
|
| 564 |
+
|
| 565 |
+
-->
|
| 566 |
+
|
| 567 |
+
## Related Project
|
| 568 |
+
|
| 569 |
+
* RetGen: [https://github.com/dreasysnail/RetGen](https://github.com/dreasysnail/RetGen). Retrieval-augmented/grounded DialoGPT and beyond. RetGen is a joint training framework that simultaneously optimizes a dense passage retriever and a knowledge-grounded text generator in an end-to-end fashion.
|
| 570 |
+
|
| 571 |
+
* Microsoft ICECAPS: [https://github.com/microsoft/icecaps](https://github.com/microsoft/icecaps).
|
| 572 |
+
|
| 573 |
+
As an orthogonal repository of this project,
|
| 574 |
+
Microsoft Icecaps is an open-source toolkit (in tensorflow) for building neural conversational systems. Icecaps provides an array of tools from recent conversation modeling and general NLP literature within a flexible paradigm that enables complex multi-task learning setups.
|
| 575 |
+
|
| 576 |
+
* Pretrained UniLM: [https://github.com/microsoft/unilm](https://github.com/microsoft/unilm)
|
| 577 |
+
* MT-DNN: [https://github.com/namisan/mt-dnn](https://github.com/namisan/mt-dnn)
|
| 578 |
+
* A chinese counterpart of DialoGPT by yangjianxin1. [https://github.com/yangjianxin1/GPT2-chitchat](https://github.com/yangjianxin1/GPT2-chitchat). We are glad to see that the MMI strategy that we used in DialoGPT has also improved the performance for this project as well!
|
| 579 |
+
|
| 580 |
+
## Contact
|
| 581 |
+
|
| 582 |
+
Please contact [[email protected]](mailto:[email protected]) if you have any questions/suggestions. However, the response will be sporadic. Please expect delay.
|
| 583 |
+
|
| 584 |
+
## Contributing
|
| 585 |
+
|
| 586 |
+
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
| 587 |
+
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
| 588 |
+
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
| 589 |
+
|
| 590 |
+
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
|
| 591 |
+
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
|
| 592 |
+
provided by the bot. You will only need to do this once across all repos using our CLA.
|
| 593 |
+
|
| 594 |
+
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
| 595 |
+
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
| 596 |
+
contact [[email protected]](mailto:[email protected]) with any additional questions or comments.
|
| 597 |
+
|
| 598 |
+
## Disclaimer
|
| 599 |
+
|
| 600 |
+
This repository aims to facilitate research in large-scale pretraining for conversational data. This toolkit contains only part of the modeling machinery needed to actually produce a model weight file in a running dialog. On its own, this model provides only information about the weights of various text spans; in order for a researcher to actually use it, they will need to bring conversational data of their own and decode the response generation from the pretrained system. Microsoft is not responsible for any generation from the 3rd party utilization of the pretrained system.
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
## Citation
|
| 605 |
+
If you use this code in your research, you can cite our [arxiv paper](https://arxiv.org/abs/1911.00536):
|
| 606 |
+
```bash
|
| 607 |
+
@inproceedings{zhang2019dialogpt,
|
| 608 |
+
title={DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation},
|
| 609 |
+
author={Yizhe Zhang and Siqi Sun and Michel Galley and Yen-Chun Chen and Chris Brockett and Xiang Gao and Jianfeng Gao and Jingjing Liu and Bill Dolan},
|
| 610 |
+
year={2020},
|
| 611 |
+
booktitle={ACL, system demonstration}
|
| 612 |
+
}
|
| 613 |
+
```
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
|
SECURITY.md
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.1 BLOCK -->
|
| 2 |
+
|
| 3 |
+
## Security
|
| 4 |
+
|
| 5 |
+
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [many more](https://opensource.microsoft.com/).
|
| 6 |
+
|
| 7 |
+
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets Microsoft's [definition](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)) of a security vulnerability, please report it to us as described below.
|
| 8 |
+
|
| 9 |
+
## Reporting Security Issues
|
| 10 |
+
|
| 11 |
+
**Please do not report security vulnerabilities through public GitHub issues.** Instead, please report them to the Microsoft Security Response Center at [[email protected]](mailto:[email protected]). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://technet.microsoft.com/en-us/security/dn606155).
|
| 12 |
+
|
| 13 |
+
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
|
| 14 |
+
|
| 15 |
+
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
|
| 16 |
+
|
| 17 |
+
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
|
| 18 |
+
* Full paths of source file(s) related to the manifestation of the issue
|
| 19 |
+
* The location of the affected source code (tag/branch/commit or direct URL)
|
| 20 |
+
* Any special configuration required to reproduce the issue
|
| 21 |
+
* Step-by-step instructions to reproduce the issue
|
| 22 |
+
* Proof-of-concept or exploit code (if possible)
|
| 23 |
+
* Impact of the issue, including how an attacker might exploit the issue
|
| 24 |
+
|
| 25 |
+
This information will help us triage your report more quickly.
|
| 26 |
+
|
| 27 |
+
## Preferred Languages
|
| 28 |
+
|
| 29 |
+
We prefer all communications to be in English.
|
| 30 |
+
|
| 31 |
+
## Policy
|
| 32 |
+
|
| 33 |
+
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
|
| 34 |
+
|
| 35 |
+
<!-- END MICROSOFT SECURITY.MD BLOCK -->
|
config.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"activation_function": "gelu_new",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"GPT2LMHeadModel"
|
| 5 |
+
],
|
| 6 |
+
"attn_pdrop": 0.1,
|
| 7 |
+
"bos_token_id": 50256,
|
| 8 |
+
"embd_pdrop": 0.1,
|
| 9 |
+
"eos_token_id": 50256,
|
| 10 |
+
"initializer_range": 0.02,
|
| 11 |
+
"layer_norm_epsilon": 1e-05,
|
| 12 |
+
"model_type": "gpt2",
|
| 13 |
+
"n_ctx": 1024,
|
| 14 |
+
"n_embd": 1024,
|
| 15 |
+
"n_head": 16,
|
| 16 |
+
"n_layer": 24,
|
| 17 |
+
"n_positions": 1024,
|
| 18 |
+
"resid_pdrop": 0.1,
|
| 19 |
+
"summary_activation": null,
|
| 20 |
+
"summary_first_dropout": 0.1,
|
| 21 |
+
"summary_proj_to_labels": true,
|
| 22 |
+
"summary_type": "cls_index",
|
| 23 |
+
"summary_use_proj": true,
|
| 24 |
+
"task_specific_params": {
|
| 25 |
+
"conversational": {
|
| 26 |
+
"max_length": 1000
|
| 27 |
+
}
|
| 28 |
+
},
|
| 29 |
+
"vocab_size": 50257
|
| 30 |
+
}
|
data_config.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
import os
|
| 4 |
+
from . import proj_env
|
| 5 |
+
|
| 6 |
+
RAW_DATA_DIR = os.path.join(proj_env.ROOT_DIR, "raw_data")
|
| 7 |
+
PROCESSED_DATA_DIR = os.path.join(proj_env.ROOT_DIR, "processed")
|
| 8 |
+
PIPELINE_DATA_DIR = os.path.join(proj_env.ROOT_DIR, "pipeline_data")
|
| 9 |
+
|
| 10 |
+
TEST_KEY_FN = os.path.join(RAW_DATA_DIR, "keys.2k.txt")
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
MAX_LEN = 128 #512
|
| 14 |
+
MAX_CONTEXT_LEN = 64#250
|
| 15 |
+
|
| 16 |
+
TAG_LIST = ["<p>", "<title>", "<anchor>"] + ["<h%s>" % i for i in range(1, 7)]
|
| 17 |
+
|
data_loader.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
import gzip
|
| 4 |
+
import json
|
| 5 |
+
import math
|
| 6 |
+
import random
|
| 7 |
+
import shelve
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
import subprocess as sp
|
| 11 |
+
|
| 12 |
+
from math import ceil
|
| 13 |
+
from torch.utils.data import DataLoader, Sampler, Dataset
|
| 14 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 15 |
+
|
| 16 |
+
from env import END_OF_TEXT_TOKEN
|
| 17 |
+
from gpt2_training.train_utils import (InputFeatures, InputFeatures_train,
|
| 18 |
+
RedditExample)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class BucketSampler(Sampler):
|
| 22 |
+
"""
|
| 23 |
+
this sampler will sort data by sequence length
|
| 24 |
+
"""
|
| 25 |
+
def __init__(self, lens, bucket_size, batch_size,
|
| 26 |
+
droplast=False, shuffle=True):
|
| 27 |
+
self._lens = lens
|
| 28 |
+
self._batch_size = batch_size
|
| 29 |
+
self._bucket_size = bucket_size
|
| 30 |
+
self._droplast = droplast
|
| 31 |
+
self._shuf = shuffle
|
| 32 |
+
|
| 33 |
+
def __iter__(self):
|
| 34 |
+
ids = list(range(len(self._lens)))
|
| 35 |
+
if self._shuf:
|
| 36 |
+
random.shuffle(ids)
|
| 37 |
+
buckets = [sorted(ids[i:i+self._bucket_size],
|
| 38 |
+
key=lambda i: self._lens[i], reverse=True)
|
| 39 |
+
for i in range(0, len(ids), self._bucket_size)]
|
| 40 |
+
batches = [bucket[i:i+self._batch_size]
|
| 41 |
+
for bucket in buckets
|
| 42 |
+
for i in range(0, len(bucket), self._batch_size)]
|
| 43 |
+
if self._droplast:
|
| 44 |
+
batches = [batch for batch in batches
|
| 45 |
+
if len(batch) == self._batch_size]
|
| 46 |
+
if self._shuf:
|
| 47 |
+
random.shuffle(batches)
|
| 48 |
+
return iter(batches)
|
| 49 |
+
|
| 50 |
+
def __len__(self):
|
| 51 |
+
bucket_sizes = ([self._bucket_size]
|
| 52 |
+
* (len(self._lens) // self._bucket_size)
|
| 53 |
+
+ [len(self._lens) % self._bucket_size])
|
| 54 |
+
if self._droplast:
|
| 55 |
+
return sum(s//self._batch_size for s in bucket_sizes)
|
| 56 |
+
else:
|
| 57 |
+
return sum(math.ceil(s/self._batch_size) for s in bucket_sizes)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class GPT2FeatureDataset(Dataset):
|
| 61 |
+
""" pytorch dataset for GPT2 training """
|
| 62 |
+
def __init__(self, features, max_len=None):
|
| 63 |
+
self.features = features
|
| 64 |
+
self.max_len = max_len # this max_len do truncate
|
| 65 |
+
|
| 66 |
+
def __getitem__(self, i):
|
| 67 |
+
feat_dict = self.features[i]
|
| 68 |
+
if self.max_len is not None and feat_dict['input_len'] > self.max_len:
|
| 69 |
+
# tuncate on the left side (context)
|
| 70 |
+
feat_dict['input_ids'] = feat_dict['input_ids'][-self.max_len:]
|
| 71 |
+
feat_dict['position_ids'] = feat_dict['position_ids'][
|
| 72 |
+
-self.max_len:]
|
| 73 |
+
feat_dict['token_type_ids'] = feat_dict['token_type_ids'][
|
| 74 |
+
-self.max_len:]
|
| 75 |
+
feat_dict['lm_labels'] = feat_dict['lm_labels'][-self.max_len:]
|
| 76 |
+
try:
|
| 77 |
+
for s in ['context_len', 'response_len']:
|
| 78 |
+
if s in feat_dict.keys():
|
| 79 |
+
print("db file missing "+s)
|
| 80 |
+
del feat_dict[s]
|
| 81 |
+
except Exception:
|
| 82 |
+
import pdb
|
| 83 |
+
pdb.set_trace()
|
| 84 |
+
|
| 85 |
+
feat = InputFeatures_train(**feat_dict)
|
| 86 |
+
return feat
|
| 87 |
+
|
| 88 |
+
def __len__(self):
|
| 89 |
+
return len(self.features)
|
| 90 |
+
|
| 91 |
+
@staticmethod
|
| 92 |
+
def collate(features):
|
| 93 |
+
input_ids = pad_sequence([torch.tensor(f.input_ids, dtype=torch.long)
|
| 94 |
+
for f in features],
|
| 95 |
+
batch_first=True, padding_value=0)
|
| 96 |
+
position_ids = pad_sequence([torch.tensor(f.position_ids,
|
| 97 |
+
dtype=torch.long)
|
| 98 |
+
for f in features],
|
| 99 |
+
batch_first=True, padding_value=0)
|
| 100 |
+
token_type_ids = pad_sequence([torch.tensor(f.token_type_ids,
|
| 101 |
+
dtype=torch.long)
|
| 102 |
+
for f in features],
|
| 103 |
+
batch_first=True, padding_value=0)
|
| 104 |
+
labels = pad_sequence([torch.tensor(f.lm_labels, dtype=torch.long)
|
| 105 |
+
for f in features],
|
| 106 |
+
batch_first=True, padding_value=-1)
|
| 107 |
+
return (input_ids, position_ids, token_type_ids, labels)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class BucketingDataLoader(object):
|
| 111 |
+
""" this loads shelve db chunks and then convert to mini-batch loader"""
|
| 112 |
+
def __init__(self, db_name, batch_size, max_seq_len,
|
| 113 |
+
bucket=100, shuffle=True):
|
| 114 |
+
self.db = shelve.open(f'{db_name}/db', 'r')
|
| 115 |
+
self.batch_size = batch_size
|
| 116 |
+
self.max_len = max_seq_len
|
| 117 |
+
self.bucket_size = bucket * batch_size
|
| 118 |
+
self.shuffle = shuffle
|
| 119 |
+
|
| 120 |
+
def _get_keys(self):
|
| 121 |
+
keys = list(self.db.keys())
|
| 122 |
+
return keys
|
| 123 |
+
|
| 124 |
+
def __iter__(self):
|
| 125 |
+
keys = self._get_keys()
|
| 126 |
+
if self.shuffle:
|
| 127 |
+
random.shuffle(keys)
|
| 128 |
+
for key in keys:
|
| 129 |
+
chunk = json.loads(gzip.decompress(self.db[key]).decode('utf-8'))
|
| 130 |
+
# discard long examples
|
| 131 |
+
trunc_chunk = []
|
| 132 |
+
lens = []
|
| 133 |
+
for feat in chunk:
|
| 134 |
+
if feat['input_len'] > self.max_len:
|
| 135 |
+
continue
|
| 136 |
+
trunc_chunk.append(feat)
|
| 137 |
+
lens.append(feat['input_len'])
|
| 138 |
+
|
| 139 |
+
dataset = GPT2FeatureDataset(trunc_chunk, self.max_len)
|
| 140 |
+
sampler = BucketSampler(lens, self.bucket_size, self.batch_size,
|
| 141 |
+
droplast=True, shuffle=self.shuffle)
|
| 142 |
+
loader = DataLoader(dataset, batch_sampler=sampler,
|
| 143 |
+
num_workers=0, # can test multi-worker
|
| 144 |
+
collate_fn=GPT2FeatureDataset.collate)
|
| 145 |
+
yield from loader
|
| 146 |
+
|
| 147 |
+
def __len__(self):
|
| 148 |
+
raise NotImplementedError()
|
| 149 |
+
|
| 150 |
+
def __del__(self):
|
| 151 |
+
self.db.close()
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class DistributedBucketingDataLoader(BucketingDataLoader):
|
| 155 |
+
""" distributed version """
|
| 156 |
+
def __init__(self, rank, num_replica, *args, **kwargs):
|
| 157 |
+
super().__init__(*args, **kwargs)
|
| 158 |
+
self.rank = rank
|
| 159 |
+
self.num_replica = num_replica
|
| 160 |
+
|
| 161 |
+
def _get_keys(self):
|
| 162 |
+
keys = list(self.db.keys())[self.rank::self.num_replica]
|
| 163 |
+
return keys
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def convert_examples_to_features_dynamic(examples, tokenizer,
|
| 167 |
+
max_seq_length=512):
|
| 168 |
+
"""
|
| 169 |
+
do not pad
|
| 170 |
+
"""
|
| 171 |
+
def featurize(example):
|
| 172 |
+
conv_id = example.conv_id
|
| 173 |
+
context_id = tokenizer.encode(example.context)
|
| 174 |
+
end_of_text_id = tokenizer.encoder[END_OF_TEXT_TOKEN]
|
| 175 |
+
|
| 176 |
+
# response is provided in example
|
| 177 |
+
response_id = tokenizer.encode(example.response)
|
| 178 |
+
|
| 179 |
+
input_ids_len = len(context_id) + len(response_id) + 2
|
| 180 |
+
if input_ids_len > max_seq_length:
|
| 181 |
+
if len(context_id) > input_ids_len - max_seq_length:
|
| 182 |
+
# cut context from beginning if length of context + response is too long
|
| 183 |
+
# and len of context is long enough to cut
|
| 184 |
+
context_id = context_id[input_ids_len - max_seq_length:]
|
| 185 |
+
else:
|
| 186 |
+
# cut response from end if length of context + response is too long
|
| 187 |
+
# and len of response is long enough to cut
|
| 188 |
+
# if no response is available, discard the data
|
| 189 |
+
if max_seq_length-len(context_id)-2 < 0:
|
| 190 |
+
return None
|
| 191 |
+
response_id = response_id[:max_seq_length-len(context_id)-2]
|
| 192 |
+
|
| 193 |
+
input_ids = context_id + [end_of_text_id] + response_id + [end_of_text_id]
|
| 194 |
+
|
| 195 |
+
# label simplely is next token in sequences. MASK all context_id tokens except for the last one
|
| 196 |
+
lm_labels = [-1] * len(context_id) + response_id + [end_of_text_id] + [-1]
|
| 197 |
+
|
| 198 |
+
position_ids = list(range(len(input_ids)))
|
| 199 |
+
|
| 200 |
+
token_type_id = [0] * len(input_ids)
|
| 201 |
+
|
| 202 |
+
return InputFeatures(conv_id, input_ids, position_ids, token_type_id,
|
| 203 |
+
lm_labels, len(context_id), len(response_id))
|
| 204 |
+
|
| 205 |
+
# discard None feature
|
| 206 |
+
features = [f for f in [featurize(ex) for ex in examples] if f is not None]
|
| 207 |
+
return features
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class DynamicBatchingLoader(object):
|
| 211 |
+
""" this loader takes raw text file, used for validate perplexity """
|
| 212 |
+
def __init__(self, corpus_file, tokenizer, normalize_data,
|
| 213 |
+
batch_size, max_seq_length):
|
| 214 |
+
self.corpus = corpus_file
|
| 215 |
+
self.toker = tokenizer
|
| 216 |
+
self.norm = normalize_data
|
| 217 |
+
self.bs = batch_size
|
| 218 |
+
self.max_seq_length = max_seq_length
|
| 219 |
+
self.num_examples = self.get_len(corpus_file)
|
| 220 |
+
|
| 221 |
+
def __iter__(self, epoch=1):
|
| 222 |
+
if epoch > 0:
|
| 223 |
+
for epoch in range(epoch):
|
| 224 |
+
yield from self._iter_epoch()
|
| 225 |
+
else:
|
| 226 |
+
while True:
|
| 227 |
+
yield from self._iter_epoch()
|
| 228 |
+
|
| 229 |
+
def __len__(self):
|
| 230 |
+
return ceil(self.num_examples/self.bs)
|
| 231 |
+
|
| 232 |
+
def _iter_epoch(self):
|
| 233 |
+
try:
|
| 234 |
+
with open(self.corpus, 'r', encoding="utf-8") as corpus:
|
| 235 |
+
i = 0
|
| 236 |
+
while True:
|
| 237 |
+
examples = []
|
| 238 |
+
cur_bs = 0
|
| 239 |
+
while True:
|
| 240 |
+
line = next(corpus).encode('utf-8').decode('utf-8')
|
| 241 |
+
contents = line.split('\t')
|
| 242 |
+
src, tgt_all = contents[0], contents[1:]
|
| 243 |
+
for tgt in tgt_all:
|
| 244 |
+
if self.norm:
|
| 245 |
+
src_line = ' '.join(src.strip().split())
|
| 246 |
+
tgt_line = ' '.join(tgt.strip().split())
|
| 247 |
+
else:
|
| 248 |
+
src_line = src.strip()
|
| 249 |
+
tgt_line = tgt.strip()
|
| 250 |
+
examples.append(
|
| 251 |
+
RedditExample(i, src_line, tgt_line),
|
| 252 |
+
)
|
| 253 |
+
i += 1
|
| 254 |
+
cur_bs += 1
|
| 255 |
+
if cur_bs >= self.bs:
|
| 256 |
+
break
|
| 257 |
+
features = convert_examples_to_features_dynamic(
|
| 258 |
+
examples, self.toker, self.max_seq_length)
|
| 259 |
+
batch = self._batch_feature(features)
|
| 260 |
+
yield batch
|
| 261 |
+
except StopIteration:
|
| 262 |
+
pass
|
| 263 |
+
|
| 264 |
+
def _batch_feature(self, features):
|
| 265 |
+
input_ids = pad_sequence([torch.tensor(f.choices_features['input_ids'],
|
| 266 |
+
dtype=torch.long)
|
| 267 |
+
for f in features],
|
| 268 |
+
batch_first=True, padding_value=0)
|
| 269 |
+
position_ids = pad_sequence(
|
| 270 |
+
[torch.tensor(f.choices_features['position_ids'], dtype=torch.long)
|
| 271 |
+
for f in features],
|
| 272 |
+
batch_first=True, padding_value=0)
|
| 273 |
+
token_type_ids = pad_sequence(
|
| 274 |
+
[torch.tensor(f.choices_features['token_type_ids'],
|
| 275 |
+
dtype=torch.long)
|
| 276 |
+
for f in features],
|
| 277 |
+
batch_first=True, padding_value=0)
|
| 278 |
+
labels = pad_sequence([torch.tensor(f.lm_labels, dtype=torch.long)
|
| 279 |
+
for f in features],
|
| 280 |
+
batch_first=True, padding_value=-1)
|
| 281 |
+
context_len = torch.tensor([f.context_len for f in features],
|
| 282 |
+
dtype=torch.long)
|
| 283 |
+
response_len = torch.tensor([f.response_len for f in features],
|
| 284 |
+
dtype=torch.long)
|
| 285 |
+
return (input_ids, position_ids, token_type_ids, labels,
|
| 286 |
+
context_len, response_len)
|
| 287 |
+
|
| 288 |
+
def get_len(self, corpus):
|
| 289 |
+
n_line = int(sp.check_output(f"wc -l {corpus}".split(),
|
| 290 |
+
universal_newlines=True).split()[0])
|
| 291 |
+
return n_line
|
demo.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
#
|
| 4 |
+
# Please assign the DATA_FOLDER before running this scripts, the data, pre-trained model, fine-tuned model will be
|
| 5 |
+
# downloaded automatically to DATA_FOLDER
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import logging
|
| 10 |
+
from functools import partial
|
| 11 |
+
|
| 12 |
+
from demo_utils import download_model_folder
|
| 13 |
+
import argparse
|
| 14 |
+
import subprocess as sp
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
PROJECT_FOLDER = os.path.dirname(os.path.realpath(__file__))
|
| 18 |
+
PYTHON_EXE = 'python'
|
| 19 |
+
MODEL_FOLDER = os.path.join(PROJECT_FOLDER, 'models')
|
| 20 |
+
DATA_FOLDER = os.path.join(PROJECT_FOLDER, 'data')
|
| 21 |
+
|
| 22 |
+
print(f'PROJECT_FOLDER = {PROJECT_FOLDER}')
|
| 23 |
+
|
| 24 |
+
parser = argparse.ArgumentParser()
|
| 25 |
+
parser.add_argument('--data', type=str, default='dummy',
|
| 26 |
+
help='choose from dummy, small and full')
|
| 27 |
+
dargs = parser.parse_args()
|
| 28 |
+
|
| 29 |
+
assert dargs.data == 'dummy' or dargs.data == 'small' or dargs.data == 'full' , \
|
| 30 |
+
'The specified data option is not support!'
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
logging.basicConfig(
|
| 34 |
+
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
| 35 |
+
datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO
|
| 36 |
+
)
|
| 37 |
+
logger = logging.getLogger(__name__)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
if os.path.exists(MODEL_FOLDER):
|
| 41 |
+
print(f'Found existing models folder at {MODEL_FOLDER}, skip creating a new one!')
|
| 42 |
+
os.makedirs(MODEL_FOLDER, exist_ok=True)
|
| 43 |
+
else:
|
| 44 |
+
os.makedirs(MODEL_FOLDER)
|
| 45 |
+
|
| 46 |
+
#########################################################################
|
| 47 |
+
# Download Model
|
| 48 |
+
#########################################################################
|
| 49 |
+
logger.info('Downloading models...')
|
| 50 |
+
download_model = partial(download_model_folder, DATA_FOLDER=MODEL_FOLDER)
|
| 51 |
+
|
| 52 |
+
# model size: could be one of 'small' (GPT2 with 117M), 'medium'(345M) or 'large' (1542M)
|
| 53 |
+
# dataset: one of 'multiref' or 'dstc'
|
| 54 |
+
# from_scratch: True : load model trained from scratch or False: load model trained from fine-tuning the GPT-2
|
| 55 |
+
target_folder = download_model(model_size='small', dataset='multiref', from_scratch=False)
|
| 56 |
+
logger.info('Done!\n')
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
#########################################################################
|
| 60 |
+
# Prepare Data
|
| 61 |
+
#########################################################################
|
| 62 |
+
logger.info('Downloading and Extracting Data...')
|
| 63 |
+
if dargs.data == 'dummy':
|
| 64 |
+
cmd = 'bash prepare4db.sh'
|
| 65 |
+
ret = sp.run(cmd.split(' '), stdout=sp.PIPE, stderr=sp.STDOUT, cwd=DATA_FOLDER)
|
| 66 |
+
elif dargs.data == 'small':
|
| 67 |
+
myCmd = os.popen('cd reddit_extractor; SIZE=small make -j 8; cd ..').read()
|
| 68 |
+
elif dargs.data == 'full':
|
| 69 |
+
myCmd = os.popen('cd reddit_extractor; SIZE=full make -j 8; cd ..').read()
|
| 70 |
+
else:
|
| 71 |
+
raise ValueError('you need to implement your own data type, or use either dummy, small, or full')
|
| 72 |
+
|
| 73 |
+
logger.info('Preparing Data...')
|
| 74 |
+
data_path = os.path.join(DATA_FOLDER, 'train.tsv')
|
| 75 |
+
MAX_LEN = 128
|
| 76 |
+
data_db = f'{data_path[:-4]}.{MAX_LEN}len.db'
|
| 77 |
+
if os.path.isdir(data_db):
|
| 78 |
+
print(f'{data_db} exists, skip prepro.py')
|
| 79 |
+
else:
|
| 80 |
+
cmd = ['prepro.py', '--corpus', data_path, '--max_seq_len', f'{MAX_LEN}']
|
| 81 |
+
cmd = ' '.join(cmd) #% {'CODE_ROOT': CODE_ROOT}
|
| 82 |
+
print(cmd)
|
| 83 |
+
ret = sp.run([PYTHON_EXE] + cmd.split(' '), stdout=sp.PIPE, stderr=sp.STDOUT, cwd=PROJECT_FOLDER)
|
| 84 |
+
if ret.returncode != 0:
|
| 85 |
+
print(f'error occurred, {ret.stdout}')
|
| 86 |
+
sys.exit(ret.returncode)
|
| 87 |
+
logger.info('Done!\n')
|
| 88 |
+
|
| 89 |
+
#########################################################################
|
| 90 |
+
# Train !
|
| 91 |
+
#########################################################################
|
| 92 |
+
logger.info('Generating training CMD!')
|
| 93 |
+
logger.info('If there is any problem, please copy (modify) and run command below')
|
| 94 |
+
logger.info('#########################################################################')
|
| 95 |
+
train_cmd = 'LSP_train.py'
|
| 96 |
+
args = [
|
| 97 |
+
'--model_name_or_path', target_folder,
|
| 98 |
+
'--init_checkpoint', os.path.join(target_folder, 'pytorch_model.bin'),
|
| 99 |
+
'--train_input_file', data_db , # file from last step
|
| 100 |
+
'--eval_input_file', './data/dummy_data.tsv', # dummy test data
|
| 101 |
+
'--output_dir', os.path.join(MODEL_FOLDER, 'output_model'),
|
| 102 |
+
'--seed', '42',
|
| 103 |
+
'--max_seq_length', '128',
|
| 104 |
+
'--train_batch_size', '512',
|
| 105 |
+
'--gradient_accumulation_steps', '8',
|
| 106 |
+
'--eval_batch_size', '64',
|
| 107 |
+
'--learning_rate', '1e-5',
|
| 108 |
+
'--num_optim_steps', '10000',
|
| 109 |
+
'--valid_step', '5000',
|
| 110 |
+
'--warmup_steps', '4000',
|
| 111 |
+
'--normalize_data', 'true',
|
| 112 |
+
'--fp16', 'true',
|
| 113 |
+
'--lr_schedule', 'noam',
|
| 114 |
+
'--loss_scale', '0.0',
|
| 115 |
+
'--no_token_id', 'true',
|
| 116 |
+
'--pbar', 'true'
|
| 117 |
+
]
|
| 118 |
+
|
| 119 |
+
arg = ' '.join(args)
|
| 120 |
+
train_cmd = train_cmd + ' ' + arg
|
| 121 |
+
print(PYTHON_EXE + ' ' +train_cmd)
|
| 122 |
+
logger.info('#########################################################################')
|
| 123 |
+
with open('./output.log', 'wb') as f:
|
| 124 |
+
process = sp.Popen([PYTHON_EXE] + train_cmd.split(' '), stdout=sp.PIPE, stderr=sp.STDOUT, cwd=PROJECT_FOLDER)
|
| 125 |
+
for line in iter(process.stdout.readline, b''):
|
| 126 |
+
sys.stdout.write(line.decode(sys.stdout.encoding))
|
| 127 |
+
f.write(line)
|
| 128 |
+
logger.info('Done!\n')
|
demo_utils.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
from pytorch_pretrained_bert.file_utils import http_get
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Note that the model size is roughly half of the GPT model because our model is saved by fp16
|
| 14 |
+
LSP_MODEL_URL = {
|
| 15 |
+
'multiref': {
|
| 16 |
+
'large_fs': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/large_fs.pkl',
|
| 17 |
+
'medium_fs': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/medium_fs.pkl',
|
| 18 |
+
'medium_ft': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/medium_ft.pkl',
|
| 19 |
+
'small_fs': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/small_fs.pkl',
|
| 20 |
+
'small_ft': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/multiref/small_ft.pkl'
|
| 21 |
+
},
|
| 22 |
+
'dstc': {
|
| 23 |
+
'medium_ft': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/DSTC/medium_ft.pkl'
|
| 24 |
+
}
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
# GPT model could be downloaded from huggingface repo
|
| 28 |
+
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
| 29 |
+
"small": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
|
| 30 |
+
"medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin",
|
| 31 |
+
"large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin"
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
CONFIG_FILE = {
|
| 35 |
+
'small': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/117M/config.json',
|
| 36 |
+
'medium': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/345M/config.json',
|
| 37 |
+
'large': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/1542M/config.json'
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
VOCAB_FILE = {
|
| 41 |
+
'small': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/117M/vocab.json',
|
| 42 |
+
'medium': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/345M/vocab.json',
|
| 43 |
+
'large': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/1542M/vocab.json'
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
MERGE_FILE = {
|
| 47 |
+
'small': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/117M/merges.txt',
|
| 48 |
+
'medium': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/345M/merges.txt',
|
| 49 |
+
'large': 'https://acvrpublicycchen.blob.core.windows.net/dialogpt/1542M/merges.txt'
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def download_file(url, folder):
|
| 54 |
+
if not os.path.exists(folder):
|
| 55 |
+
os.makedirs(folder, exist_ok=True)
|
| 56 |
+
|
| 57 |
+
file_name = os.path.basename(url)
|
| 58 |
+
if 'pytorch_model.bin' in file_name:
|
| 59 |
+
file_name = 'pytorch_model.bin'
|
| 60 |
+
|
| 61 |
+
if os.path.isfile(os.path.join(folder, file_name)):
|
| 62 |
+
logger.info(f'{os.path.join(folder, file_name)} exists, return!')
|
| 63 |
+
return
|
| 64 |
+
|
| 65 |
+
with open(os.path.join(folder, file_name), 'wb') as f:
|
| 66 |
+
http_get(url, f)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def download_model_folder(model_size, dataset=None, from_scratch=None, DATA_FOLDER=None):
|
| 70 |
+
assert DATA_FOLDER is not None, 'DATA_FOLDER cannot be None'
|
| 71 |
+
assert model_size in ['small', 'medium', 'large'], 'model size should be one of \'small\', \'medium\' or \'large\''
|
| 72 |
+
target_folder = os.path.join(DATA_FOLDER, model_size)
|
| 73 |
+
download_file(CONFIG_FILE[model_size], target_folder)
|
| 74 |
+
download_file(VOCAB_FILE[model_size], target_folder)
|
| 75 |
+
download_file(MERGE_FILE[model_size], target_folder)
|
| 76 |
+
download_file(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP[model_size], target_folder)
|
| 77 |
+
if dataset is not None:
|
| 78 |
+
assert dataset in ['multiref', 'dstc'], \
|
| 79 |
+
'dataset has to be \'multiref\' or \'dstc\''
|
| 80 |
+
assert from_scratch in [True, False], 'from scratch has to be True or False'
|
| 81 |
+
|
| 82 |
+
if from_scratch:
|
| 83 |
+
model_train_type = model_size + '_fs'
|
| 84 |
+
else:
|
| 85 |
+
model_train_type = model_size + '_ft'
|
| 86 |
+
if model_train_type not in LSP_MODEL_URL[dataset]:
|
| 87 |
+
k = ','.join(list(LSP_MODEL_URL[dataset].keys()))
|
| 88 |
+
raise ValueError(f'\'{model_train_type}\' not exist for dataset \'{dataset}\', please choose from [{k}]')
|
| 89 |
+
download_file(LSP_MODEL_URL[dataset][model_train_type], target_folder)
|
| 90 |
+
return target_folder
|
| 91 |
+
|
env.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
END_OF_TURN_TOKEN = '<|endofturn|>'
|
| 7 |
+
END_OF_TEXT_TOKEN = '<|endoftext|>'
|
| 8 |
+
PROJECT_FOLDER = os.path.dirname(__file__)
|
gradiodemo.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# A 3rd party demo contributed by Github user AK391 (https://github.com/AK391). This is not implemented by Microsoft and Microsoft do not own any IP with this implementation and associated demo.
|
| 2 |
+
# Microsoft has not tested the generation of this demo and is not responsible for any offensive or biased generation from this demo.
|
| 3 |
+
# Please contact the creator AK391 (https://github.com/AK391) for any potential issue.
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 7 |
+
import torch
|
| 8 |
+
import gradio as gr
|
| 9 |
+
|
| 10 |
+
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
|
| 11 |
+
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
|
| 12 |
+
|
| 13 |
+
def dialogpt(text):
|
| 14 |
+
# encode the new user input, add the eos_token and return a tensor in Pytorch
|
| 15 |
+
for step in range(50000):
|
| 16 |
+
|
| 17 |
+
new_user_input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors='pt')
|
| 18 |
+
|
| 19 |
+
# append the new user input tokens to the chat history
|
| 20 |
+
bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids
|
| 21 |
+
|
| 22 |
+
# generated a response while limiting the total chat history to 1000 tokens,
|
| 23 |
+
chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
|
| 24 |
+
|
| 25 |
+
# pretty print last ouput tokens from bot
|
| 26 |
+
return tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
|
| 27 |
+
|
| 28 |
+
inputs = gr.inputs.Textbox(lines=1, label="Input Text")
|
| 29 |
+
outputs = gr.outputs.Textbox(label="DialoGPT")
|
| 30 |
+
|
| 31 |
+
title = "DialoGPT"
|
| 32 |
+
description = "demo for Microsoft DialoGPT with Hugging Face transformers. To use it, simply input text or click one of the examples text to load them. Read more at the links below. *This is not a Microsoft product and is developed for Gradio*"
|
| 33 |
+
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1911.00536'>DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation</a> | <a href='https://github.com/microsoft/DialoGPT'>Github Repo</a> | <a href='https://huggingface.co/microsoft/DialoGPT-large'>Hugging Face DialoGPT-large</a></p>"
|
| 34 |
+
examples = [
|
| 35 |
+
["Hi, how are you?"],
|
| 36 |
+
["How far away is the moon?"],
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
gr.Interface(dialogpt, inputs, outputs, title=title, description=description, article=article, examples=examples).launch(debug=True)
|
prepro.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
"""
|
| 4 |
+
preprocess input data into feature and stores binary as python shelve DB
|
| 5 |
+
each chunk is gzipped JSON string
|
| 6 |
+
"""
|
| 7 |
+
import argparse
|
| 8 |
+
import gzip
|
| 9 |
+
import json
|
| 10 |
+
import subprocess as sp
|
| 11 |
+
import shelve
|
| 12 |
+
import os
|
| 13 |
+
from os.path import dirname, exists, join
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from lsp_model import GPT2Tokenizer
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
|
| 19 |
+
from env import END_OF_TEXT_TOKEN
|
| 20 |
+
from gpt2_training.train_utils import InputFeatures_train as InputFeatures
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _get_file_len(corpus):
|
| 24 |
+
n_line = int(sp.check_output(f"wc -l {corpus}".split(),
|
| 25 |
+
universal_newlines=True).split()[0])
|
| 26 |
+
return n_line
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _norm_text(text):
|
| 30 |
+
w, *toks = text.strip().split()
|
| 31 |
+
try:
|
| 32 |
+
w = float(w)
|
| 33 |
+
except Exception:
|
| 34 |
+
toks = [w] + toks
|
| 35 |
+
w = 1.0
|
| 36 |
+
return w, ' '.join(toks)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _get_inputs_from_text(text, tokenizer):
|
| 40 |
+
srcs, tgt = text.strip().split('\t')
|
| 41 |
+
weights = []
|
| 42 |
+
inputs = []
|
| 43 |
+
for src in srcs.split(' EOS '):
|
| 44 |
+
src_weight, src = _norm_text(src)
|
| 45 |
+
context_id = tokenizer.encode(src)
|
| 46 |
+
weights.append(src_weight)
|
| 47 |
+
inputs.append(context_id)
|
| 48 |
+
tgt_weight, tgt = _norm_text(tgt)
|
| 49 |
+
if tgt_weight != 0:
|
| 50 |
+
response_id = tokenizer.encode(tgt)
|
| 51 |
+
weights.append(tgt_weight)
|
| 52 |
+
inputs.append(response_id)
|
| 53 |
+
return weights, inputs
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _make_features(id_, weights, inputs, tokenizer, max_len):
|
| 57 |
+
end_of_text_id = tokenizer.encoder[END_OF_TEXT_TOKEN]
|
| 58 |
+
features = []
|
| 59 |
+
sents = []
|
| 60 |
+
ws = []
|
| 61 |
+
len_ = 0
|
| 62 |
+
i = 0
|
| 63 |
+
for ids, w in zip(inputs, weights):
|
| 64 |
+
if len(ids) > max_len:
|
| 65 |
+
if len(sents) >= 2:
|
| 66 |
+
feat = _make_feature(id_ + i, sents, ws, end_of_text_id)
|
| 67 |
+
if feat is not None:
|
| 68 |
+
features.append(feat)
|
| 69 |
+
i += 1
|
| 70 |
+
len_ = 0
|
| 71 |
+
sents = []
|
| 72 |
+
ws = []
|
| 73 |
+
continue
|
| 74 |
+
elif len_ > max_len:
|
| 75 |
+
feat = _make_feature(id_ + i, sents, ws, end_of_text_id)
|
| 76 |
+
if feat is not None:
|
| 77 |
+
features.append(feat)
|
| 78 |
+
i += 1
|
| 79 |
+
len_ = len(sents[-1]) + 1
|
| 80 |
+
sents = sents[-1:]
|
| 81 |
+
ws = ws[-1:]
|
| 82 |
+
len_ += (len(ids) + 1)
|
| 83 |
+
sents.append(ids)
|
| 84 |
+
ws.append(w)
|
| 85 |
+
if len(sents) >= 2:
|
| 86 |
+
feat = _make_feature(id_ + i, sents, ws, end_of_text_id)
|
| 87 |
+
if feat is not None:
|
| 88 |
+
features.append(feat)
|
| 89 |
+
|
| 90 |
+
return features
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _make_feature(id_, sents, ws, eos):
|
| 94 |
+
if all(w == 0 for w in ws[1:]):
|
| 95 |
+
return None
|
| 96 |
+
input_ids = [i for s in sents for i in s+[eos]][:-1]
|
| 97 |
+
lm_labels = []
|
| 98 |
+
weights = []
|
| 99 |
+
token_type_ids = [] # this becomes round ids
|
| 100 |
+
for i, (s, w) in enumerate(zip(sents, ws)):
|
| 101 |
+
if i == 0:
|
| 102 |
+
lm_labels += [-1] * len(s)
|
| 103 |
+
weights += [0.0] * len(s)
|
| 104 |
+
token_type_ids += [0] * len(s)
|
| 105 |
+
continue
|
| 106 |
+
|
| 107 |
+
token_type_ids += [i] * (len(s) + 1)
|
| 108 |
+
if w == 0.0:
|
| 109 |
+
lm_labels += [-1] * (len(s) + 1)
|
| 110 |
+
weights += [0.0] * (len(s) + 1)
|
| 111 |
+
else:
|
| 112 |
+
lm_labels += (s + [eos])
|
| 113 |
+
weights += [w] * (len(s) + 1)
|
| 114 |
+
|
| 115 |
+
# handle trailing -1's
|
| 116 |
+
i = len(lm_labels) - 1
|
| 117 |
+
while i >= 0:
|
| 118 |
+
if lm_labels[i] != -1:
|
| 119 |
+
break
|
| 120 |
+
i -= 1
|
| 121 |
+
input_ids = input_ids[:i+1]
|
| 122 |
+
lm_labels = lm_labels[:i+1]
|
| 123 |
+
weights = weights[:i+1]
|
| 124 |
+
token_type_ids = token_type_ids[:i+1]
|
| 125 |
+
|
| 126 |
+
# pad to multiples of 8
|
| 127 |
+
while len(input_ids) % 8 != 0:
|
| 128 |
+
input_ids.append(0)
|
| 129 |
+
token_type_ids.append(0)
|
| 130 |
+
lm_labels.append(-1)
|
| 131 |
+
weights.append(0.0)
|
| 132 |
+
|
| 133 |
+
position_ids = list(range(len(input_ids)))
|
| 134 |
+
assert (len(input_ids) == len(position_ids) == len(token_type_ids)
|
| 135 |
+
== len(lm_labels) == len(weights))
|
| 136 |
+
assert len(input_ids) % 8 == 0
|
| 137 |
+
if len(input_ids) == 0:
|
| 138 |
+
import pdb
|
| 139 |
+
pdb.set_trace()
|
| 140 |
+
feature = InputFeatures(id_, input_ids, position_ids, token_type_ids,
|
| 141 |
+
lm_labels, weights)
|
| 142 |
+
return feature
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def main(args):
|
| 146 |
+
toker = GPT2Tokenizer.from_pretrained('gpt2')
|
| 147 |
+
attrs = []
|
| 148 |
+
if args.reverse:
|
| 149 |
+
attrs.append('reverse')
|
| 150 |
+
if args.two_turn:
|
| 151 |
+
attrs.append('2turn')
|
| 152 |
+
if attrs:
|
| 153 |
+
db_path = (f'{args.corpus[:-4]}.{args.max_seq_len}len.'
|
| 154 |
+
f'{".".join(attrs)}.db/db')
|
| 155 |
+
else:
|
| 156 |
+
db_path = f'{args.corpus[:-4]}.{args.max_seq_len}len.db/db'
|
| 157 |
+
if exists(dirname(db_path)):
|
| 158 |
+
raise ValueError('Found existing DB, please backup')
|
| 159 |
+
else:
|
| 160 |
+
os.makedirs(dirname(db_path))
|
| 161 |
+
with open(args.corpus, "r", encoding="utf-8") as reader, \
|
| 162 |
+
shelve.open(db_path, 'n') as db:
|
| 163 |
+
chunk = []
|
| 164 |
+
n_chunk = 0
|
| 165 |
+
n_example = 0
|
| 166 |
+
for line in tqdm(reader, total=_get_file_len(args.corpus)):
|
| 167 |
+
try:
|
| 168 |
+
if len(chunk) >= args.chunk_size:
|
| 169 |
+
# save and renew chunk
|
| 170 |
+
db[f'chunk_{n_chunk}'] = gzip.compress(
|
| 171 |
+
json.dumps(chunk[:args.chunk_size]).encode('utf-8'))
|
| 172 |
+
chunk = chunk[args.chunk_size:]
|
| 173 |
+
n_chunk += 1
|
| 174 |
+
|
| 175 |
+
weights, inputs = _get_inputs_from_text(line, toker)
|
| 176 |
+
if args.reverse:
|
| 177 |
+
weights = list(reversed(weights))
|
| 178 |
+
inputs = list(reversed(inputs))
|
| 179 |
+
if args.two_turn:
|
| 180 |
+
weights = weights[:2]
|
| 181 |
+
inputs = inputs[:2]
|
| 182 |
+
if len(weights) < 2:
|
| 183 |
+
continue
|
| 184 |
+
features = _make_features(n_example, weights, inputs,
|
| 185 |
+
toker, args.max_seq_len)
|
| 186 |
+
for feature in features:
|
| 187 |
+
chunk.append(vars(feature))
|
| 188 |
+
n_example += 1
|
| 189 |
+
except Exception as e:
|
| 190 |
+
print('!!! prepro exception !!!', e)
|
| 191 |
+
continue
|
| 192 |
+
# save last chunk
|
| 193 |
+
db[f'chunk_{n_chunk}'] = gzip.compress(
|
| 194 |
+
json.dumps(chunk).encode('utf-8'))
|
| 195 |
+
# save relevant information to reproduce
|
| 196 |
+
meta = {'n_example': n_example,
|
| 197 |
+
'chunk_size': args.chunk_size,
|
| 198 |
+
'max_seq_len': args.max_seq_len,
|
| 199 |
+
'reverse': args.reverse,
|
| 200 |
+
'two_turn': args.two_turn}
|
| 201 |
+
with open(join(dirname(db_path), 'meta.json'), 'w') as writer:
|
| 202 |
+
json.dump(meta, writer, indent=4)
|
| 203 |
+
torch.save(toker, join(dirname(db_path), 'tokenizer.pt'))
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
if __name__ == '__main__':
|
| 207 |
+
parser = argparse.ArgumentParser()
|
| 208 |
+
parser.add_argument('--corpus', required=True,
|
| 209 |
+
help='file name of training corpus (should be .tsv)')
|
| 210 |
+
parser.add_argument('--chunk_size', type=int, default=65536,
|
| 211 |
+
help='num of data examples in a storing chunk')
|
| 212 |
+
parser.add_argument('--max_seq_len', type=int, default=128,
|
| 213 |
+
help='discard data longer than this')
|
| 214 |
+
parser.add_argument('--reverse', action='store_true',
|
| 215 |
+
help='reverse the src tgt')
|
| 216 |
+
parser.add_argument('--two_turn', action='store_true',
|
| 217 |
+
help='take only the first 2 turns')
|
| 218 |
+
|
| 219 |
+
args = parser.parse_args()
|
| 220 |
+
|
| 221 |
+
main(args)
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
transformers
|
| 2 |
+
ftfy
|
| 3 |
+
regex
|
| 4 |
+
tqdm
|
| 5 |
+
torch
|
| 6 |
+
torchvision
|
| 7 |
+
pycrypto
|
| 8 |
+
flashtext
|
| 9 |
+
zstandard
|
| 10 |
+
nltk
|