yyfz233 commited on
Commit
955bde4
·
1 Parent(s): 1a5d31c

Change weight loading method

Browse files
Files changed (2) hide show
  1. app.py +1 -3
  2. example.py +0 -4
app.py CHANGED
@@ -562,9 +562,7 @@ if __name__ == '__main__':
562
 
563
  print("Initializing and loading Pi3 model...")
564
 
565
- model = Pi3()
566
- _URL = "https://huggingface.co/yyfz233/Pi3/resolve/main/model.safetensors"
567
- model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
568
  # model.load_state_dict(torch.load('ckpts/pi3.pt', weights_only=False, map_location=device))
569
 
570
  model.eval()
 
562
 
563
  print("Initializing and loading Pi3 model...")
564
 
565
+ model = Pi3.from_pretrained("yyfz233/Pi3")
 
 
566
  # model.load_state_dict(torch.load('ckpts/pi3.pt', weights_only=False, map_location=device))
567
 
568
  model.eval()
example.py CHANGED
@@ -41,10 +41,6 @@ if __name__ == '__main__':
41
  model.load_state_dict(weight)
42
  else:
43
  model = Pi3.from_pretrained("yyfz233/Pi3").to(device).eval()
44
- # or
45
- # model = Pi3().to(device).eval()
46
- # _URL = "https://huggingface.co/yyfz233/Pi3/resolve/main/model.safetensors"
47
- # model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
48
 
49
  # 2. Prepare input data
50
  # The load_images_as_tensor function will print the loading path
 
41
  model.load_state_dict(weight)
42
  else:
43
  model = Pi3.from_pretrained("yyfz233/Pi3").to(device).eval()
 
 
 
 
44
 
45
  # 2. Prepare input data
46
  # The load_images_as_tensor function will print the loading path