Update prediction_multilabel.py (#13)
Browse files
- Update prediction_multilabel.py (6cf7ff8274c11e281210cc7372117f5d5a80f322)
Co-authored-by: Hamid Bekamiri <[email protected]>
- prediction_multilabel.py +6 -4
prediction_multilabel.py
CHANGED
@@ -5,20 +5,22 @@ import pandas as pd
|
|
5 |
import torch
|
6 |
import torch.nn as nn
|
7 |
from sentence_transformers import util
|
|
|
8 |
|
9 |
# Set random seed for reproducibility
|
10 |
torch.manual_seed(1)
|
11 |
|
|
|
12 |
# Load datasets
|
13 |
-
df_inmemory = pd.read_csv('raw_data/labeled.csv') # labeled text extracted from 230 CSR GRI reports, 150 International companies, 2017-2021 period
|
14 |
-
df_paragraph = pd.read_csv('raw_data/prediction_demo.csv')
|
15 |
|
16 |
# Load stored embeddings
|
17 |
-
with open('embeddings/embeddings_prediction.pkl', "rb") as f:
|
18 |
stored_data = pickle.load(f)
|
19 |
pred_embeddings = stored_data['parg_embeddings']
|
20 |
|
21 |
-
with open('embeddings/embeddings_labeled.pkl', "rb") as f:
|
22 |
stored_data = pickle.load(f)
|
23 |
embeddings = stored_data['sent_embeddings']
|
24 |
|
|
|
5 |
import torch
|
6 |
import torch.nn as nn
|
7 |
from sentence_transformers import util
|
8 |
+
import os
|
9 |
|
10 |
# Set random seed for reproducibility
|
11 |
torch.manual_seed(1)
|
12 |
|
13 |
+
path = os.getcwd()
|
14 |
# Load datasets
|
15 |
+
df_inmemory = pd.read_csv(path + '/raw_data/labeled.csv') # labeled text extracted from 230 CSR GRI reports, 150 International companies, 2017-2021 period
|
16 |
+
df_paragraph = pd.read_csv(path + '/raw_data/prediction_demo.csv', encoding='latin1')
|
17 |
|
18 |
# Load stored embeddings
|
19 |
+
with open(path + '/embeddings/embeddings_prediction.pkl', "rb") as f:
|
20 |
stored_data = pickle.load(f)
|
21 |
pred_embeddings = stored_data['parg_embeddings']
|
22 |
|
23 |
+
with open(path + '/embeddings/embeddings_labeled.pkl', "rb") as f:
|
24 |
stored_data = pickle.load(f)
|
25 |
embeddings = stored_data['sent_embeddings']
|
26 |
|