File size: 538 Bytes
75a3e98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import argparse
from src.model import AllMAssistant

def main(argv=None):
    """Run ``AllMAssistant.generate`` on a prompt and print the result.

    Parses ``--model_dir``, ``--prompt`` and ``--max_new_tokens``, builds an
    ``AllMAssistant`` from the ``--model_dir`` value, calls its ``generate``
    method with the prompt and token setting, and prints whatever it returns.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Leaving it as ``None`` keeps the normal
            command-line behaviour, so existing ``main()`` callers are
            unaffected.
    """
    parser = argparse.ArgumentParser(
        description='Run AllMAssistant.generate on a prompt and print the result.',
        # Appends "(default: ...)" to every option's help text.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        '--model_dir',
        default='gpt2',
        help='model name or directory handed to AllMAssistant',
    )
    parser.add_argument(
        '--prompt',
        default='Create a 10-minute beginner home workout for fat loss.',
        help='prompt text passed to generate()',
    )
    parser.add_argument(
        '--max_new_tokens',
        type=int,
        default=200,
        help='max_new_tokens value forwarded to generate()',
    )
    # parse_args(None) falls back to sys.argv[1:], matching the old behaviour.
    args = parser.parse_args(argv)

    assistant = AllMAssistant(args.model_dir)
    out = assistant.generate(args.prompt, max_new_tokens=args.max_new_tokens)
    print(out)

# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if '__main__' == __name__:
    main()