File size: 4,307 Bytes
5fbd25d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
# -*- coding: utf-8 -*-

""" Some tools

@file: tools.py
@author: Konie
@update: 2024-03-22
"""
# pylint: disable=line-too-long
# pylint: disable=broad-exception-caught
import os
import re
import subprocess
import sys
from importlib import metadata
from importlib.util import find_spec
from typing import Mapping, Optional

from packaging import version


# Interpreter running this module; pip is invoked through it so packages land
# in the same environment.
PYTHON_EXEC = sys.executable
# Optional package index override taken from the INDEX_URL environment
# variable at import time ("" when the variable is unset).
INDEX_URL = os.getenv('INDEX_URL', "")
# Matches ``name`` or ``name==version`` at the start of a requirements line:
# group 1 is the package name, group 2 the pinned version (None when absent).
PATTERN = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")


# This function was copied from [Fooocus](https://github.com/lllyasviel/Fooocus) repository.
def run_command(command: str,
                desc: Optional[str] = None,
                error_desc: Optional[str] = None,
                custom_env: Optional[Mapping[str, str]] = None,
                live: bool = True) -> str:
    """
    Run a shell command and return its captured output.

    Args:
        command: Command line to run. It is executed through the system
            shell (``shell=True``), so it must not contain untrusted input.
        desc: Message printed before the command runs, if given.
        error_desc: Prefix of the error message raised on failure; defaults
            to ``'Error running command'``.
        custom_env: Environment mapping for the child process; defaults to
            the current ``os.environ``.
        live: If True, the child's stdout/stderr go straight to the console
            and are not captured; if False, both are captured.
    Returns:
        The captured stdout when ``live`` is False; ``""`` when ``live`` is
        True (nothing is captured) or the command printed nothing.
    Raises:
        RuntimeError: If the command exits with a non-zero status. The message
            contains the command, the exit code and, when captured, the
            command's stdout/stderr.
    """
    if desc is not None:
        print(desc)

    run_kwargs = {
        "args": command,
        "shell": True,
        "env": os.environ if custom_env is None else custom_env,
        "encoding": 'utf8',
        # Undecodable bytes are dropped rather than raising mid-run.
        "errors": 'ignore'
    }

    if not live:
        run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE

    # check=False: the exit status is inspected below so the error can carry
    # the captured output.
    result = subprocess.run(check=False, **run_kwargs)

    if result.returncode != 0:
        error_bits = [
            f"{error_desc or 'Error running command'}.",
            f"Command: {command}",
            f"Error code: {result.returncode}",
        ]
        if result.stdout:
            error_bits.append(f"stdout: {result.stdout}")
        if result.stderr:
            error_bits.append(f"stderr: {result.stderr}")
        raise RuntimeError("\n".join(error_bits))

    # stdout is None when it was not captured (live mode).
    return result.stdout or ""


# This function was copied from [Fooocus](https://github.com/lllyasviel/Fooocus) repository.
def run_pip(command: str, desc: Optional[str] = None, live: bool = True) -> Optional[str]:
    """
    Run a pip command with the current interpreter, best-effort.

    ``--prefer-binary`` is always appended. ``--index-url`` is also appended
    when the module-level ``INDEX_URL`` is non-empty.

    Args:
        command: pip sub-command and arguments, e.g. ``"install numpy"``.
        desc: Name of what is being installed, used in the progress and
            error messages; defaults to ``command`` itself.
        live: Whether pip's output is streamed to the console instead of
            being captured.
    Returns:
        The output returned by ``run_command`` (``""`` when ``live`` is True),
        or None if the command failed. Failures are printed, not raised.
    """
    # Avoid messages like "Installing None" when no description is given.
    label = command if desc is None else desc
    index_url_line = f' --index-url {INDEX_URL}' if INDEX_URL else ''
    try:
        return run_command(
            command=f'"{PYTHON_EXEC}" -m pip {command} --prefer-binary{index_url_line}',
            desc=f"Installing {label}",
            error_desc=f"Couldn't install {label}",
            live=live
        )
    except Exception as e:
        # Deliberately best-effort: report the failure and let the caller go on.
        print(f'CMD Failed {command}: {e}')
        return None


def is_installed(package: str) -> bool:
    """
    Check if a module can be imported.

    ``package`` is an import name (e.g. ``"cv2"``), not a distribution name.
    For a dotted name, ``find_spec`` imports the parent package, but the
    module itself is not imported.

    Args:
        package: Module name, possibly dotted (``"a.b"``).
    Returns:
        True if a module spec is found or the module is already loaded;
        False otherwise.
    """
    try:
        spec = find_spec(package)
    except ImportError:
        # Includes ModuleNotFoundError for a missing parent of a dotted name,
        # plus plain ImportError (e.g. a relative name with no package).
        return False
    except ValueError:
        # find_spec raises ValueError when the module is in sys.modules but
        # has no __spec__ (e.g. ``__main__`` run as a script).
        return package in sys.modules

    return spec is not None


def check_torch_cuda() -> bool:
    """
    Check whether PyTorch can be imported and reports CUDA as available.

    Returns:
        True if ``torch`` imports and ``torch.cuda.is_available()`` is True.
        False if torch is missing, fails to load, or reports no CUDA.
    """
    try:
        import torch  # pylint: disable=import-outside-toplevel
        return bool(torch.cuda.is_available())
    except (ImportError, OSError):
        # OSError covers a torch install whose native libraries fail to load
        # (e.g. missing or mismatched shared/DLL dependencies).
        return False


# pip treats '#' at line start, or preceded by whitespace, as the start of a
# comment; a '#' inside a token (e.g. a URL fragment) is left alone.
_REQUIREMENT_COMMENT = re.compile(r"(^|\s)#.*$")


def _strip_requirement_comment(line: str) -> str:
    """Return *line* without a pip-style comment and surrounding whitespace."""
    return _REQUIREMENT_COMMENT.sub("", line).strip()


def _versions_equal(required: str, installed: str) -> bool:
    """Compare two versions per PEP 440, falling back to exact text equality."""
    try:
        return version.parse(required) == version.parse(installed)
    except version.InvalidVersion:
        # Non-PEP 440 strings cannot be parsed; only an exact match counts.
        return required == installed


def requirements_check(requirements_file: str = 'requirements.txt',
                       pattern: re.Pattern = PATTERN) -> bool:
    """
    Check if the requirements file is satisfied by the current environment.

    Blank lines and pip-style comments are ignored. Every other line must
    match ``pattern``. When ``pattern`` captures a version in group 2
    (``name==version`` with the default pattern), the installed distribution
    must have that exact version. Lines without a captured version are not
    checked further.

    Args:
        requirements_file: Path to the requirements file.
        pattern: Regex whose group 1 is the package name and optional group 2
            the required version.
    Returns:
        Whether every requirement line is satisfied.
    """
    # utf-8-sig also accepts a leading BOM, which would otherwise make the
    # first line fail to match ``pattern``.
    with open(requirements_file, "r", encoding="utf-8-sig") as file:
        for raw_line in file:
            line = _strip_requirement_comment(raw_line)
            if not line:
                continue

            m = re.match(pattern, line)
            if m is None:
                return False

            package = m.group(1).strip()
            version_required = (m.group(2) or "").strip()
            if not version_required:
                # No pinned version captured by ``pattern``: nothing to verify.
                continue

            try:
                version_installed = metadata.version(package)
            except Exception:
                # Missing distribution (or unreadable metadata) means unmet.
                return False

            if not _versions_equal(version_required, version_installed):
                return False

    return True