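"""Evaluation entry point for competition submissions.

For script-type competitions, this downloads the participant's submission
repository, runs its script.py inside the sandbox wrapper under a time limit,
and uploads the generated submission file(s) to the competition dataset repo.
It then computes the metrics and updates the submission status.

Invoked with --config pointing to a JSON file that deserializes into
EvalParams; the fields used here include competition_id, competition_type,
dataset, submission_repo, submission_id, team_id, submission_filenames,
output_path, time_limit, and token.
"""
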
import argparse
import json
import os
import shlex
import shutil
import subprocess

from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from huggingface_hub.utils import EntryNotFoundError
from loguru import logger

from competitions import utils
from competitions.compute_metrics import compute_metrics
from competitions.enums import SubmissionStatus
from competitions.params import EvalParams


def parse_args():
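    """Parse command-line arguments; --config is the path to the JSON eval config."""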
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    return parser.parse_args()


def upload_submission_file(params, file_path):
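    """Placeholder for uploading a single submission file; currently a no-op
    (the actual upload happens via HfApi.upload_file in generate_submission_file)."""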
    logger.info("Uploading submission file")
    pass


def generate_submission_file(params):
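    """Download the participant's submission repo, run its script.py inside the
    sandbox wrapper under params.time_limit, and upload the generated submission
    file(s) to the competition dataset repo under submissions/."""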
    logger.info("Downloading submission dataset")
    submission_dir = snapshot_download(
        repo_id=params.submission_repo,
        local_dir=params.output_path,
        token=os.environ.get("USER_TOKEN"),
        repo_type="model",
    )
    # submission_dir contains a script.py file; run it in a subprocess so that
    # it writes the expected submission file(s) into submission_dir, then push
    # those files to the competition dataset repo (see the upload loop below)
    logger.info("Generating submission file")

    # Clear the USER_TOKEN env var before running the untrusted user script
    os.environ["USER_TOKEN"] = ""

    # Copy sandbox to submission_dir
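    # (sandbox is assumed to be a wrapper executable shipped alongside this
    # evaluator that confines what the user's script.py can do)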
    shutil.copyfile("sandbox", f"{submission_dir}/sandbox")
    sandbox_path = f"{submission_dir}/sandbox"
    os.chmod(sandbox_path, 0o755)
    os.chown(sandbox_path, os.getuid(), os.getgid())

    # Command: run the user's script.py through the sandbox wrapper
    cmd = f"{sandbox_path} python script.py"
    cmd = shlex.split(cmd)

    # Copy the current environment for the subprocess
    env = os.environ.copy()

    # Start the subprocess
    process = subprocess.Popen(cmd, cwd=submission_dir, env=env)

    # Wait for the process to complete or timeout
    try:
        process.wait(timeout=params.time_limit)
    except subprocess.TimeoutExpired:
        logger.info(f"Process exceeded {params.time_limit} seconds time limit. Terminating...")
        process.kill()
        process.wait()

    # Check the exit status (non-zero after a timeout kill or a script failure)
    if process.returncode != 0:
        logger.error("Subprocess did not terminate successfully")
    else:
        logger.info("Subprocess terminated successfully")

    logger.info("contents of submission_dir")
    logger.info(os.listdir(submission_dir))

    api = HfApi(token=params.token)
    for sub_file in params.submission_filenames:
        logger.info(f"Uploading {sub_file} to the repository")
        sub_file_ext = sub_file.split(".")[-1]
        api.upload_file(
            path_or_fileobj=f"{submission_dir}/{sub_file}",
            path_in_repo=f"submissions/{params.team_id}-{params.submission_id}.{sub_file_ext}",
            repo_id=params.competition_id,
            repo_type="dataset",
        )


@utils.monitor
def run(params):
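    """Run a single evaluation end to end.

    Marks the submission as processing, and for script competitions installs the
    competition's requirements (if any), downloads its dataset, and generates and
    uploads the submission file. It then computes the public/private scores,
    marks the submission as successful, and deletes the evaluation Space.
    """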
    logger.info(params)
    if isinstance(params, dict):
        params = EvalParams(**params)

    utils.update_submission_status(params, SubmissionStatus.PROCESSING.value)

    if params.competition_type == "script":
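        # For script competitions, optionally install the competition's
        # requirements.txt and download its dataset before running the
        # participant's script.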
        try:
            requirements_fname = hf_hub_download(
                repo_id=params.competition_id,
                filename="requirements.txt",
                token=params.token,
                repo_type="dataset",
            )
        except EntryNotFoundError:
            requirements_fname = None

        if requirements_fname:
            logger.info("Installing requirements")
            utils.uninstall_requirements(requirements_fname)
            utils.install_requirements(requirements_fname)
        if len(str(params.dataset).strip()) > 0:
            # _ = Repository(local_dir="/tmp/data", clone_from=params.dataset, token=params.token)
            _ = snapshot_download(
                repo_id=params.dataset,
                local_dir="/tmp/data",
                token=params.token,
                repo_type="dataset",
            )
        generate_submission_file(params)

    evaluation = compute_metrics(params)

    utils.update_submission_score(params, evaluation["public_score"], evaluation["private_score"])
    utils.update_submission_status(params, SubmissionStatus.SUCCESS.value)
    utils.delete_space(params)


if __name__ == "__main__":
    args = parse_args()
    with open(args.config, encoding="utf-8") as f:
        _params = json.load(f)
    _params = EvalParams(**_params)
    run(_params)