ia-nechaev HamidBekam commited on
Commit
d34d716
·
verified ·
1 Parent(s): bc78cab

Update prediction_multilabel.py (#13)

Browse files

- Update prediction_multilabel.py (6cf7ff8274c11e281210cc7372117f5d5a80f322)


Co-authored-by: Hamid Bekamiri <[email protected]>

Files changed (1) hide show
  1. prediction_multilabel.py +6 -4
prediction_multilabel.py CHANGED
@@ -5,20 +5,22 @@ import pandas as pd
5
  import torch
6
  import torch.nn as nn
7
  from sentence_transformers import util
 
8
 
9
  # Set random seed for reproducibility
10
  torch.manual_seed(1)
11
 
 
12
  # Load datasets
13
- df_inmemory = pd.read_csv('raw_data/labeled.csv') # labeled text extracted from 230 CSR GRI reports, 150 International companies, 2017-2021 period
14
- df_paragraph = pd.read_csv('raw_data/prediction_demo.csv') # paragraphs to predict the label, extracted from 1.2k CSR reports, 150 German PLC companies, 2010-2021 period, 645k paragraphs)
15
 
16
  # Load stored embeddings
17
- with open('embeddings/embeddings_prediction.pkl', "rb") as f:
18
  stored_data = pickle.load(f)
19
  pred_embeddings = stored_data['parg_embeddings']
20
 
21
- with open('embeddings/embeddings_labeled.pkl', "rb") as f:
22
  stored_data = pickle.load(f)
23
  embeddings = stored_data['sent_embeddings']
24
 
 
5
  import torch
6
  import torch.nn as nn
7
  from sentence_transformers import util
8
+ import os
9
 
10
  # Set random seed for reproducibility
11
  torch.manual_seed(1)
12
 
13
+ path = os.getcwd()
14
  # Load datasets
15
+ df_inmemory = pd.read_csv(path + '/raw_data/labeled.csv') # labeled text extracted from 230 CSR GRI reports, 150 International companies, 2017-2021 period
16
+ df_paragraph = pd.read_csv(path + '/raw_data/prediction_demo.csv', encoding='latin1')
17
 
18
  # Load stored embeddings
19
+ with open(path + '/embeddings/embeddings_prediction.pkl', "rb") as f:
20
  stored_data = pickle.load(f)
21
  pred_embeddings = stored_data['parg_embeddings']
22
 
23
+ with open(path + '/embeddings/embeddings_labeled.pkl', "rb") as f:
24
  stored_data = pickle.load(f)
25
  embeddings = stored_data['sent_embeddings']
26