feat: Windows Support (#53)

* feat: add OS-agnostic library for existing shell command

* feat: add OS-agnostic library for existing shell command, function tweaks and move to test folder

* feat: windows version of pre-flight script

* feat: replace existing shell command with os agnostic library

* [bug] windows upload files

* [bug] windows upload models

* merge conflicts

---------

Co-authored-by: yike5460 <yike5460@163.com>
pull/65/head
Yan 2023-07-10 17:37:53 +08:00 committed by GitHub
parent ec0de92ea9
commit 0e7be2955b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 867 additions and 41 deletions

View File

@ -23,6 +23,8 @@ from datetime import datetime
import math
import re
from utils import cp, tar, rm
inference_job_dropdown = None
textual_inversion_dropdown = None
hyperNetwork_dropdown = None
@ -196,7 +198,7 @@ def get_inference_job_list(txt2img_type_checkbox=True, img2img_type_checkbox=Tru
except Exception as e:
print("Exception occurred when fetching inference_job_ids")
return gr.Dropdown.update(choices=[])
@ -338,7 +340,7 @@ def refresh_all_models():
ckpt_type = ckpt["type"]
checkpoint_info[ckpt_type] = {}
for ckpt_name in ckpt["name"]:
ckpt_s3_pos = f"{ckpt['s3Location']}/{ckpt_name.split('/')[-1]}"
ckpt_s3_pos = f"{ckpt['s3Location']}/{ckpt_name.split(os.sep)[-1]}"
checkpoint_info[ckpt_type][ckpt_name] = ckpt_s3_pos
except Exception as e:
print(f"Error refresh all models: {e}")
@ -355,7 +357,7 @@ def sagemaker_upload_model_s3(sd_checkpoints_path, textual_inversion_path, lora_
if lp == "" or not lp:
continue
print(f"lp is {lp}")
model_name = lp.split("/")[-1]
model_name = lp.split(os.sep)[-1]
exist_model_list = list(checkpoint_info[rp].keys())
@ -403,23 +405,28 @@ def sagemaker_upload_model_s3(sd_checkpoints_path, textual_inversion_path, lora_
s3_signed_urls_resp = json_response["s3PresignUrl"][local_tar_path]
# Upload src model to S3.
if rp != "embeddings" :
local_model_path_in_repo = f'models/{rp}/{model_name}'
local_model_path_in_repo = os.sep.join(['models', rp, model_name])
else:
local_model_path_in_repo = f'{rp}/{model_name}'
local_model_path_in_repo = os.sep.join([rp, model_name])
#local_tar_path = f'{model_name}.tar'
print("Pack the model file.")
os.system(f"cp -f {lp} {local_model_path_in_repo}")
# os.system(f"cp -f {lp} {local_model_path_in_repo}")
cp(lp, local_model_path_in_repo, recursive=True)
if rp == "Stable-diffusion":
model_yaml_name = model_name.split('.')[0] + ".yaml"
local_model_yaml_path = "/".join(lp.split("/")[:-1]) + f"/{model_yaml_name}"
local_model_yaml_path_in_repo = f"models/{rp}/{model_yaml_name}"
local_model_yaml_path = os.sep.join([*lp.split(os.sep)[:-1], model_yaml_name])
local_model_yaml_path_in_repo = os.sep.join(["models", rp, model_yaml_name])
if os.path.isfile(local_model_yaml_path):
os.system(f"cp -f {local_model_yaml_path} {local_model_yaml_path_in_repo}")
os.system(f"tar cvf {local_tar_path} {local_model_path_in_repo} {local_model_yaml_path_in_repo}")
# os.system(f"cp -f {local_model_yaml_path} {local_model_yaml_path_in_repo}")
# os.system(f"tar cvf {local_tar_path} {local_model_path_in_repo} {local_model_yaml_path_in_repo}")
cp(local_model_yaml_path, local_model_yaml_path_in_repo, recursive=True)
tar(mode='c', archive=local_tar_path, sfiles=[local_model_path_in_repo, local_model_yaml_path_in_repo], verbose=True)
else:
os.system(f"tar cvf {local_tar_path} {local_model_path_in_repo}")
# os.system(f"tar cvf {local_tar_path} {local_model_path_in_repo}")
tar(mode='c', archive=local_tar_path, sfiles=[local_model_path_in_repo], verbose=True)
else:
os.system(f"tar cvf {local_tar_path} {local_model_path_in_repo}")
# os.system(f"tar cvf {local_tar_path} {local_model_path_in_repo}")
tar(mode='c', archive=local_tar_path, sfiles=[local_model_path_in_repo], verbose=True)
#upload_file_to_s3_by_presign_url(local_tar_path, s3_presigned_url)
multiparts_tags = upload_multipart_files_to_s3_by_signed_url(
local_tar_path,
@ -439,7 +446,8 @@ def sagemaker_upload_model_s3(sd_checkpoints_path, textual_inversion_path, lora_
log = f"\n finish upload {local_tar_path} to {s3_base}"
os.system(f"rm {local_tar_path}")
# os.system(f"rm {local_tar_path}")
rm(local_tar_path, recursive=True)
except Exception as e:
print(f"fail to upload model {lp}, error: {e}")
@ -753,7 +761,7 @@ def fake_gan(selected_value: str ):
if inference_job_status == 'inprogress':
return [], [], plaintext_to_html('inference still in progress')
if inference_job_taskType in ["txt2img", "img2img"]:
if inference_job_taskType in ["txt2img", "img2img"]:
prompt_txt = ''
images = get_inference_job_image_output(inference_job_id)
image_list = []
@ -786,7 +794,7 @@ def fake_gan(selected_value: str ):
prompt_txt = caption
image_list = [] # Return an empty list if selected_value is None
json_list = []
info_text = ''
info_text = ''
infotexts = ''
else:
prompt_txt = ''
@ -854,9 +862,9 @@ def create_ui(is_img2img):
with gr.Row():
global sagemaker_endpoint
sagemaker_endpoint = gr.Dropdown(sagemaker_endpoints,
label="Select Cloud SageMaker Endpoint",
elem_id="sagemaker_endpoint_dropdown"
)
label="Select Cloud SageMaker Endpoint",
elem_id="sagemaker_endpoint_dropdown"
)
modules.ui.create_refresh_button(sagemaker_endpoint, update_sagemaker_endpoints, lambda: {"choices": sagemaker_endpoints}, "refresh_sagemaker_endpoints")
with gr.Row():

View File

@ -22,6 +22,8 @@ os.environ['IGNORE_CMD_ARGS_ERRORS'] = ""
from dreambooth.ui_functions import start_training
from dreambooth.shared import status
# BUGFIX(review): "utls" is almost certainly a typo for the sibling "utils"
# module imported elsewhere in this change set - confirm no utls.py exists.
from utils import tar, mv
def sync_status_from_s3_json(bucket_name, webui_status_file_path, sagemaker_status_file_path):
while True:
time.sleep(1)
@ -98,19 +100,16 @@ def upload_model_to_s3_v2(model_name, s3_output_path, model_type):
yaml = os.path.join(root, f"{ckpt_name}.yaml")
output_tar = file
tar_command = f"tar cvf {output_tar} {safetensors} {yaml}"
print(tar_command)
os.system(tar_command)
logger.info(f"Upload check point to s3 {output_tar} {output_bucket_name} {s3_output_path}")
print(f"Upload check point to s3 {output_tar} {output_bucket_name} {s3_output_path}")
upload_file_to_s3(output_tar, output_bucket_name, os.path.join(s3_output_path, model_name))
elif model_type == "Lora":
output_tar = file
tar_command = f"tar cvf {output_tar} {safetensors}"
print(tar_command)
os.system(tar_command)
logger.info(f"Upload check point to s3 {output_tar} {output_bucket_name} {s3_output_path}")
print(f"Upload check point to s3 {output_tar} {output_bucket_name} {s3_output_path}")
upload_file_to_s3(output_tar, output_bucket_name, s3_output_path)
print(tar_command)
# os.system(tar_command)
tar(mode='c', archive=output_tar, sfiles=[safetensors, yaml], verbose=True)
logger.info(f"Upload check point to s3 {output_tar} {output_bucket_name} {s3_output_path}")
print(f"Upload check point to s3 {output_tar} {output_bucket_name} {s3_output_path}")
upload_file_to_s3(output_tar, output_bucket_name, s3_output_path)
def download_data(data_list, s3_data_path_list, s3_input_path):
for data, data_tar in zip(data_list, s3_data_path_list):
@ -145,7 +144,8 @@ def prepare_for_training(s3_model_path, model_name, s3_input_path, data_tar_list
download_db_config_path = f"models/sagemaker_dreambooth/{model_name}/db_config_cloud.json"
target_db_config_path = f"models/dreambooth/{model_name}/db_config.json"
logger.info(f"Move db_config to correct position {download_db_config_path} {target_db_config_path}")
os.system(f"mv {download_db_config_path} {target_db_config_path}")
# os.system(f"mv {download_db_config_path} {target_db_config_path}")
mv(download_db_config_path, target_db_config_path)
with open(target_db_config_path) as db_config_file:
db_config = json.load(db_config_file)
data_list = []
@ -171,7 +171,8 @@ def prepare_for_training_v2(s3_model_path, model_name, s3_input_path, s3_data_pa
download_db_config_path = f"models/sagemaker_dreambooth/{model_name}/db_config_cloud.json"
target_db_config_path = f"models/dreambooth/{model_name}/db_config.json"
logger.info(f"Move db_config to correct position {download_db_config_path} {target_db_config_path}")
os.system(f"mv {download_db_config_path} {target_db_config_path}")
# os.system(f"mv {download_db_config_path} {target_db_config_path}")
mv(download_db_config_path, target_db_config_path)
with open(target_db_config_path) as db_config_file:
db_config = json.load(db_config_file)
data_list = []

View File

@ -10,8 +10,8 @@ import logging
from modules import sd_models
from utils import upload_multipart_files_to_s3_by_signed_url
from utils import get_variable_from_json
from utils import tar
import gradio as gr
logging.basicConfig(filename='sd-aws-ext.log', level=logging.ERROR, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
@ -144,7 +144,8 @@ def async_create_model_on_sagemaker(
multiparts_tags=[]
if not from_hub:
print("Pack the model file.")
os.system(f"tar cvf {local_tar_path} {local_model_path}")
# os.system(f"tar cvf {local_tar_path} {local_model_path}")
tar(mode='c', archive=local_tar_path, sfiles=[local_model_path], verbose=True)
s3_base = json_response["job"]["s3_base"]
print(f"Upload to S3 {s3_base}")
print(f"Model ID: {model_id}")

View File

@ -9,6 +9,7 @@ import logging
import shutil
from utils import upload_file_to_s3_by_presign_url
from utils import get_variable_from_json
from utils import tar, cp
logging.basicConfig(filename='sd-aws-ext.log', level=logging.ERROR, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
@ -114,7 +115,8 @@ def async_prepare_for_training_on_sagemaker(
url += "train"
upload_files = []
db_config_tar = f"db_config.tar"
os.system(f"tar cvf {db_config_tar} {db_config_path}")
# os.system(f"tar cvf {db_config_tar} {db_config_path}")
tar(mode='c', archive=db_config_tar, sfiles=db_config_path, verbose=True)
upload_files.append(db_config_tar)
new_data_list = []
for data_path in data_path_list:
@ -125,7 +127,8 @@ def async_prepare_for_training_on_sagemaker(
data_tar = f'data-{data_path.replace("/", "-").strip("-")}.tar'
new_data_list.append(data_tar)
print("Pack the data file.")
os.system(f"tar cf {data_tar} {data_path}")
# os.system(f"tar cf {data_tar} {data_path}")
tar(mode='c', archive=data_tar, sfiles=data_path, verbose=False)
upload_files.append(data_tar)
else:
new_data_list.append(data_path)
@ -139,7 +142,8 @@ def async_prepare_for_training_on_sagemaker(
new_class_data_list.append(class_data_tar)
upload_files.append(class_data_tar)
print("Pack the class data file.")
os.system(f"tar cf {class_data_tar} {class_data_path}")
# os.system(f"tar cf {class_data_tar} {class_data_path}")
tar(mode='c', archive=class_data_tar, sfiles=[class_data_path], verbose=False)
else:
new_class_data_list.append(class_data_path)
payload = {

112
pre-flight.bat Normal file
View File

@ -0,0 +1,112 @@
@echo off
SETLOCAL ENABLEDELAYEDEXPANSION

REM Earliest commit of each repository that this extension supports.
set INITIAL_SUPPORT_COMMIT_ROOT=89f9faa6
set INITIAL_SUPPORT_COMMIT_CONTROLNET=7c674f83
set INITIAL_SUPPORT_COMMIT_DREAMBOOTH=926ae204

REM Repositories paired with their checkout folder ("url=folder").
REM BUGFIX: the original kept two parallel quoted lists and iterated them in
REM nested loops, which paired every URL with every folder (cartesian product)
REM and cloned each repo into both folders.
set REPO_PAIR_LIST="https://github.com/Mikubill/sd-webui-controlnet.git=sd-webui-controlnet" "https://github.com/d8ahazard/sd_dreambooth_extension.git=sd_dreambooth_extension"

REM BUGFIX: dispatch to the option parser immediately. The original defined
REM all labels first and put "call :parse_options %*" at the very end of the
REM file; execution fell straight into :show_help and hit "goto :eof", so the
REM command-line options were never parsed.
call :parse_options %*
exit /b

:show_help
echo Usage: %~nx0 -p/--pre-flight -s/--version-sync
goto :eof

:get_supported_commit_list
REM %1 = repo url, %2 = first supported commit, %3 = latest commit.
REM Prints the supported commit range to stdout.
set repo_url=%~1
set initial_support_commit=%~2
set latest_commit=%~3
for /f "tokens=*" %%i in ('git rev-list --topo-order %initial_support_commit%^..%latest_commit%') do echo %%i
goto :eof

:get_latest_commit_id
REM %1 = repo url; result is returned in the latest_commit_id variable.
set repo_url=%~1
for /f "tokens=1" %%i in ('git ls-remote "%repo_url%" HEAD ^| findstr /b "[0-9a-f]"') do set latest_commit_id=%%i
echo %latest_commit_id%
goto :eof

:pre_flight_check
echo Start pre-flight check for WebUI...
call :get_latest_commit_id "https://github.com/AUTOMATIC1111/stable-diffusion-webui.git"
set LATEST_ROOT_COMMIT=%latest_commit_id%
echo Supported commits for WebUI:
call :get_supported_commit_list "https://github.com/AUTOMATIC1111/stable-diffusion-webui.git" "%INITIAL_SUPPORT_COMMIT_ROOT%" "%LATEST_ROOT_COMMIT%"
REM NOTE(review): get_supported_commit_list only echoes its result, so
REM supported_commit_list is never populated and SUPPORTED_ROOT_COMMITS
REM stays empty - confirm the intended data flow.
set SUPPORTED_ROOT_COMMITS=%supported_commit_list%
for /f "tokens=*" %%i in ('git -C ..\.. rev-parse HEAD') do set CUR_ROOT_COMMIT=%%i
echo Current commit id for WebUI: %CUR_ROOT_COMMIT%
echo Pre-flight checks complete.
goto :eof

:version_sync
echo Start version sync for WebUI, make sure the extension folder is empty...
set extension_folder=%~1
if not exist "%extension_folder%" (
    echo The extension folder does not exist: %extension_folder%
    echo Please create it and run the script again.
    goto :eof
)
echo Syncing WebUI...
for %%p in (%REPO_PAIR_LIST%) do (
    REM Split each "url=folder" pair into its two components.
    for /f "tokens=1,2 delims==" %%a in (%%p) do (
        set repo_url=%%a
        set repo_folder=%%b
        call :get_latest_commit_id !repo_url!
        set latest_commit=!latest_commit_id!
        if not exist "%extension_folder%\!repo_folder!" (
            echo Cloning !repo_url! into !repo_folder!...
            git clone !repo_url! "%extension_folder%\!repo_folder!"
            REM BUGFIX: use pushd/popd; the original "cd %cd%" expanded %cd%
            REM at parse time and so never returned to the start directory.
            pushd "%extension_folder%\!repo_folder!"
            git checkout !latest_commit!
            popd
        ) else (
            echo Updating !repo_folder! to the latest commit...
            pushd "%extension_folder%\!repo_folder!"
            git fetch origin
            git checkout !latest_commit!
            popd
        )
    )
)
echo Version sync complete.
goto :eof

:parse_options
if "%~1" == "" (
    call :show_help
    goto :eof
)
for %%o in (%*) do (
    if "%%o" == "-p" (
        call :pre_flight_check
        exit /b
    ) else if "%%o" == "--pre-flight" (
        call :pre_flight_check
        exit /b
    ) else if "%%o" == "-s" (
        call :version_sync "extensions"
        exit /b
    ) else if "%%o" == "--version-sync" (
        call :version_sync "extensions"
        exit /b
    ) else if "%%o" == "-h" (
        call :show_help
        exit /b
    ) else if "%%o" == "--help" (
        call :show_help
        exit /b
    ) else (
        echo Unknown option: %%o
    )
)
goto :eof

View File

@ -1,3 +1,4 @@
boto3>=1.26.28
requests
urllib
urllib
psutil==5.9.5

273
test/test_windows.py Normal file
View File

@ -0,0 +1,273 @@
import os
import shutil
import filecmp
import unittest
import errno
import stat
import psutil
import sys
# append the path to the utils module
# sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from windows import tar, rm, cp, mv, df
class TestTarFunction(unittest.TestCase):
    """Round-trip tests for tar(): create archives from both a file list and
    a directory name, extract them, and compare contents with the originals."""
    def setUp(self):
        # Work inside a scratch directory; tolerate a leftover directory
        # (EEXIST) from a previously aborted run.
        self.test_dir = 'test_tar_function'
        try:
            os.makedirs(self.test_dir)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
        os.chdir(self.test_dir)
        with open('file1.txt', 'w') as f:
            f.write('This is file1.')
        with open('file2.txt', 'w') as f:
            f.write('This is file2.')
    def tearDown(self):
        # Leave the scratch directory before removing it.
        os.chdir('..')
        shutil.rmtree(self.test_dir)
    def test_tar_function(self):
        # Test creating a new archive with file list as input
        tar(mode='c', archive='archive.tar', sfiles=['file1.txt', 'file2.txt'])
        # Verify the archive was created
        self.assertTrue(os.path.exists('archive.tar'))
        # Create a folder for extracted files
        try:
            os.makedirs('extracted')
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
        # Test extracting files from the archive in test
        # assemble the path to the archive
        archive_path = os.path.join(os.getcwd(), 'archive.tar')
        tar(mode='x', archive=archive_path, change_dir='extracted')
        # Verify the files were extracted
        self.assertTrue(os.path.exists('extracted/file1.txt'))
        self.assertTrue(os.path.exists('extracted/file2.txt'))
        # Compare the original and extracted files
        self.assertTrue(filecmp.cmp('file1.txt', 'extracted/file1.txt'))
        self.assertTrue(filecmp.cmp('file2.txt', 'extracted/file2.txt'))
        # Test creating a folder name as input
        tar(mode='c', archive='archive2.tar', sfiles='extracted')
        # Verify the archive was created
        self.assertTrue(os.path.exists('archive2.tar'))
        # Create a folder for extracted files
        try:
            os.makedirs('extracted2')
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
        # Test extracting files from the archive in test
        # assemble the path to the archive
        archive_path = os.path.join(os.getcwd(), 'archive2.tar')
        tar(mode='x', archive=archive_path, change_dir='extracted2')
        # Verify the files were extracted (members keep their relative path,
        # hence the nested "extracted/" prefix).
        self.assertTrue(os.path.exists('extracted2/extracted/file1.txt'))
        self.assertTrue(os.path.exists('extracted2/extracted/file2.txt'))
        # Compare the original and extracted files
        self.assertTrue(filecmp.cmp('file1.txt', 'extracted2/extracted/file1.txt'))
        self.assertTrue(filecmp.cmp('file2.txt', 'extracted2/extracted/file2.txt'))
class TestRmFunction(unittest.TestCase):
    """Tests for the rm() shell-command replacement."""
    def setUp(self):
        self.test_dir = 'test_rm_function'
        # BUGFIX: tolerate a leftover scratch directory from an aborted
        # previous run, consistent with TestTarFunction.setUp.
        try:
            os.makedirs(self.test_dir)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
        os.chdir(self.test_dir)
        with open('file1.txt', 'w') as f:
            f.write('This is file1.')
        os.makedirs('dir1')
        with open('dir1/file2.txt', 'w') as f:
            f.write('This is file2 inside dir1.')
    def tearDown(self):
        # Leave the scratch directory before removing it.
        os.chdir('..')
        shutil.rmtree(self.test_dir)
    def test_rm_file(self):
        rm('file1.txt')
        self.assertFalse(os.path.exists('file1.txt'))
    def test_rm_directory_without_recursive(self):
        # Removing a directory requires recursive=True.
        with self.assertRaises(ValueError):
            rm('dir1')
    def test_rm_directory_with_recursive(self):
        rm('dir1', recursive=True)
        self.assertFalse(os.path.exists('dir1'))
    def test_rm_nonexistent_file_without_force(self):
        with self.assertRaises(ValueError):
            rm('nonexistent_file.txt')
    def test_rm_nonexistent_file_with_force(self):
        # force=True must swallow the missing-path error.
        try:
            rm('nonexistent_file.txt', force=True)
        except ValueError:
            self.fail("rm() raised ValueError unexpectedly with force=True")
class TestCpFunction(unittest.TestCase):
    """Tests for the cp() shell-command replacement."""
    def setUp(self):
        self.test_dir = 'test_cp_function'
        # BUGFIX: tolerate a leftover scratch directory from an aborted
        # previous run, consistent with TestTarFunction.setUp.
        try:
            os.makedirs(self.test_dir)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
        os.chdir(self.test_dir)
        with open('file1.txt', 'w') as f:
            f.write('This is file1.')
        os.makedirs('dir1')
        with open('dir1/file2.txt', 'w') as f:
            f.write('This is file2 inside dir1.')
    def tearDown(self):
        # Leave the scratch directory before removing it.
        os.chdir('..')
        shutil.rmtree(self.test_dir)
    def test_cp_file(self):
        cp('file1.txt', 'file1_copy.txt')
        self.assertTrue(os.path.exists('file1_copy.txt'))
    def test_cp_directory_without_recursive(self):
        # Copying a directory requires recursive=True.
        with self.assertRaises(ValueError):
            cp('dir1', 'dir1_copy')
    def test_cp_directory_with_recursive(self):
        cp('dir1', 'dir1_copy', recursive=True)
        self.assertTrue(os.path.exists('dir1_copy'))
        self.assertTrue(os.path.exists('dir1_copy/file2.txt'))
    def test_cp_dereference_symlink(self):
        # NOTE(review): os.symlink on Windows may require administrator
        # rights or Developer Mode - confirm this passes on target hosts.
        os.symlink('file1.txt', 'file1_symlink.txt')
        cp('file1_symlink.txt', 'file1_dereferenced.txt', dereference=True)
        self.assertTrue(os.path.exists('file1_dereferenced.txt'))
    def test_cp_preserve_file_metadata(self):
        os.chmod('file1.txt', stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
        cp('file1.txt', 'file1_preserved.txt', preserve=True)
        src_stat = os.stat('file1.txt')
        dst_stat = os.stat('file1_preserved.txt')
        self.assertEqual(src_stat.st_mode, dst_stat.st_mode)
    def test_cp_not_preserve_file_metadata(self):
        os.chmod('file1.txt', stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
        cp('file1.txt', 'file1_not_preserved.txt', preserve=False)
        # Modify the file mode of the destination file so the modes differ.
        os.chmod('file1_not_preserved.txt', stat.S_IRUSR | stat.S_IWUSR)
        src_stat = os.stat('file1.txt')
        dst_stat = os.stat('file1_not_preserved.txt')
        self.assertNotEqual(src_stat.st_mode, dst_stat.st_mode)
class TestMvFunction(unittest.TestCase):
    """Tests for the mv() shell-command replacement."""
    def setUp(self):
        self.test_dir = 'test_mv_function'
        # BUGFIX: tolerate a leftover scratch directory from an aborted
        # previous run, consistent with TestTarFunction.setUp.
        try:
            os.makedirs(self.test_dir)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
        os.chdir(self.test_dir)
        with open('file1.txt', 'w') as f:
            f.write('This is file1.')
        os.makedirs('dir1')
        with open('dir1/file2.txt', 'w') as f:
            f.write('This is file2 inside dir1.')
    def tearDown(self):
        # Leave the scratch directory before removing it.
        os.chdir('..')
        shutil.rmtree(self.test_dir)
    def test_rename_file(self):
        mv('file1.txt', 'file1_renamed.txt')
        self.assertFalse(os.path.exists('file1.txt'))
        self.assertTrue(os.path.exists('file1_renamed.txt'))
    def test_move_file_to_new_directory(self):
        os.makedirs('new_directory')
        mv('file1.txt', 'new_directory/file1.txt')
        self.assertFalse(os.path.exists('file1.txt'))
        self.assertTrue(os.path.exists('new_directory/file1.txt'))
    def test_move_directory(self):
        os.makedirs('destination_directory')
        mv('dir1', 'destination_directory/dir1')
        self.assertFalse(os.path.exists('dir1'))
        self.assertTrue(os.path.exists('destination_directory/dir1'))
    def test_force_move_overwrite_file(self):
        with open('existing_destination_file.txt', 'w') as f:
            f.write('This is the existing destination file.')
        mv('file1.txt', 'existing_destination_file.txt', force=True)
        self.assertFalse(os.path.exists('file1.txt'))
        self.assertTrue(os.path.exists('existing_destination_file.txt'))
    def test_force_move_overwrite_directory(self):
        # With force=True the destination directory is replaced wholesale,
        # so its previous contents (file3.txt) must be gone afterwards.
        os.makedirs('existing_destination_directory')
        with open('existing_destination_directory/file3.txt', 'w') as f:
            f.write('This is file3 inside the existing destination directory.')
        mv('dir1', 'existing_destination_directory', force=True)
        self.assertFalse(os.path.exists('dir1'))
        self.assertTrue(os.path.exists('existing_destination_directory'))
        self.assertTrue(os.path.exists('existing_destination_directory/file2.txt'))
        self.assertFalse(os.path.exists('existing_destination_directory/file3.txt'))
    def test_move_nonexistent_file(self):
        with self.assertRaises(FileNotFoundError):
            mv('nonexistent_file.txt', 'some_destination.txt')
    def test_move_file_to_existing_destination_without_force(self):
        with open('existing_destination_file.txt', 'w') as f:
            f.write('This is the existing destination file.')
        with self.assertRaises(FileExistsError):
            mv('file1.txt', 'existing_destination_file.txt')
class TestDfFunction(unittest.TestCase):
    """Smoke tests for df(); results depend on the host's mounted filesystems."""
    def test_df_default_options(self):
        # Every entry must expose the documented keys: sizes as strings,
        # percent as a float.
        filesystems = df()
        for filesystem in filesystems:
            self.assertIsInstance(filesystem['filesystem'], str)
            self.assertIsInstance(filesystem['total'], str)
            self.assertIsInstance(filesystem['used'], str)
            self.assertIsInstance(filesystem['free'], str)
            self.assertIsInstance(filesystem['percent'], float)
            self.assertIsInstance(filesystem['mountpoint'], str)
    def test_df_show_all(self):
        # show_all=True can only add filesystems, never drop any.
        filesystems_all = df(show_all=True)
        filesystems_default = df()
        self.assertGreaterEqual(len(filesystems_all), len(filesystems_default))
    def test_df_human_readable(self):
        # Human-readable sizes must end with a unit suffix.
        filesystems = df(human_readable=True)
        for filesystem in filesystems:
            self.assertIsInstance(filesystem['filesystem'], str)
            self.assertIsInstance(filesystem['total'], str)
            self.assertIsInstance(filesystem['used'], str)
            self.assertIsInstance(filesystem['free'], str)
            self.assertIsInstance(filesystem['percent'], float)
            self.assertIsInstance(filesystem['mountpoint'], str)
            self.assertTrue(filesystem['total'][-1] in ['B', 'K', 'M', 'G', 'T', 'P'])
            self.assertTrue(filesystem['used'][-1] in ['B', 'K', 'M', 'G', 'T', 'P'])
            self.assertTrue(filesystem['free'][-1] in ['B', 'K', 'M', 'G', 'T', 'P'])
if __name__ == '__main__':
    # Discover and run all test cases in this module.
    unittest.main()

208
test/windows.py Normal file
View File

@ -0,0 +1,208 @@
import os
import tarfile
import shutil
from pathlib import Path
import psutil
def tar(mode, archive, sfiles=None, verbose=False, change_dir=None):
    """
    Create or extract a tar archive (OS-agnostic replacement for the ``tar``
    shell command).

    Args:
        mode: 'c' to create an archive, 'x' to extract one.
        archive: the archive file name.
        sfiles: when creating, a list of files or a single directory-name
            string to add; when extracting, the members to extract
            (None extracts everything).
        verbose: print each file name as it is processed.
        change_dir: directory to extract into; None uses the current
            directory (ignored when creating).

    Raises:
        ValueError: if ``mode`` is neither 'c' nor 'x'. (BUGFIX: the
            original silently did nothing for an invalid mode.)

    Usage:
        # Create a new archive
        tar(mode='c', archive='archive.tar', sfiles=['file1.txt', 'file2.txt'])
        # Extract files from an archive
        tar(mode='x', archive='archive.tar')
        # Create a new archive with verbose mode and input directory
        tar(mode='c', archive='archive.tar', sfiles='./some_directory', verbose=True)
        # Extract files from an archive with verbose mode and change directory
        tar(mode='x', archive='archive.tar', verbose=True, change_dir='./some_directory')
    """
    if mode == 'c':
        with tarfile.open(archive, mode='w') as tf:
            # A list means explicit members; a string means a directory tree.
            if isinstance(sfiles, list):
                for file in sfiles:
                    if verbose:
                        print(f"Adding {file} to {archive}")
                    tf.add(file)
            else:
                for folder_path, _subfolders, files in os.walk(sfiles):
                    for file in files:
                        member = os.path.join(folder_path, file)
                        if verbose:
                            print(f"Adding {member} to {archive}")
                        tf.add(member)
    elif mode == 'x':
        with tarfile.open(archive, mode='r') as tf:
            # Default to every member in the archive.
            if not sfiles:
                sfiles = tf.getnames()
            for file in sfiles:
                if verbose:
                    print(f"Extracting {file} from {archive}")
                # NOTE(review): tarfile.extract does not sanitize member
                # paths; only use with trusted archives.
                if change_dir:
                    tf.extract(file, path=change_dir)
                else:
                    tf.extract(file)
    else:
        raise ValueError(f"Unsupported tar mode: {mode!r} (expected 'c' or 'x')")
def rm(path, force=False, recursive=False):
    """
    Remove a file or directory, mirroring the ``rm`` shell command.

    Args:
        path (str): Path of the file or directory to remove.
        force (bool): If True, ignore non-existent files and errors. Default is False.
        recursive (bool): If True, remove directories and their contents recursively. Default is False.

    Usage:
        # Remove the file
        rm(dst)
        # Remove a directory recursively
        rm("directory/path", recursive=True)
    """
    target = Path(path)
    try:
        # Plain files and non-directory symlinks are unlinked directly.
        if target.is_file() or (target.is_symlink() and not target.is_dir()):
            target.unlink()
            return
        if target.is_dir():
            if not recursive:
                raise ValueError("Cannot remove directory without recursive=True")
            shutil.rmtree(path)
            return
        raise ValueError("File or directory does not exist")
    except Exception:
        # force acts like "rm -f": swallow every failure.
        if not force:
            raise
def cp(src, dst, recursive=False, dereference=False, preserve=True):
    """
    Copy a file or directory from source path to destination path,
    mirroring the ``cp`` shell command.

    Args:
        src (str): Source file or directory path.
        dst (str): Destination file or directory path.
        recursive (bool): If True, copy directory and its contents recursively. Default is False.
        dereference (bool): If True, always dereference symbolic links. Default is False.
        preserve (bool): If True, preserve file metadata. Default is True.

    Usage:
        # Copy the file
        cp("source/file/path.txt", "destination/file/path.txt")
        # Copy a directory recursively and dereference symlinks
        cp("source/directory", "destination/directory", recursive=True, dereference=True)
    """
    source = Path(src)
    target = Path(dst)
    if dereference:
        # Resolve symlinks so the link target, not the link, is copied.
        source = source.resolve()
    if source.is_dir() and recursive:
        if preserve:
            shutil.copytree(source, target, copy_function=shutil.copy2, symlinks=not dereference)
        else:
            shutil.copytree(source, target, symlinks=not dereference)
    elif source.is_file():
        copier = shutil.copy2 if preserve else shutil.copy
        copier(source, target)
    else:
        raise ValueError("Source must be a file or a directory with recursive=True")
def mv(src, dest, force=False):
    """
    Move or rename files and directories, mirroring the ``mv`` shell command.

    Args:
        src (str): Source file or directory path.
        dest (str): Destination file or directory path.
        force (bool): If True, overwrite the destination if it exists. Default is False.

    Raises:
        FileNotFoundError: when ``src`` does not exist.
        FileExistsError: when ``dest`` exists and ``force`` is False.

    Usage:
        # Rename a file
        mv('old_name.txt', 'new_name.txt')
        # Move a file to a new directory
        mv('file.txt', 'new_directory/file.txt')
        # Move a directory to another directory
        mv('source_directory', 'destination_directory')
        # Force move (overwrite) a file or directory
        mv('source_file.txt', 'existing_destination_file.txt', force=True)
    """
    source = Path(src)
    target = Path(dest)
    if not source.exists():
        raise FileNotFoundError(f"Source path '{src}' does not exist")
    if target.exists():
        if not force:
            raise FileExistsError(f"Destination path '{dest}' already exists and 'force' is not set")
        # Clear the destination first so shutil.move performs a plain rename.
        if target.is_file():
            target.unlink()
        elif target.is_dir():
            shutil.rmtree(target)
    if source.is_file() or source.is_dir():
        shutil.move(src, dest)
def format_size(size, human_readable):
    """
    Format a byte count for display.

    Args:
        size (int): number of bytes.
        human_readable (bool): when True, scale to B/K/M/G/T/P with one
            decimal place; when False, return the raw number as a string.

    Returns:
        str: the formatted size.
    """
    if not human_readable:
        return str(size)
    for unit in ['B', 'K', 'M', 'G', 'T']:
        if size < 1024:
            return f"{size:.1f}{unit}"
        size /= 1024
    # BUGFIX: the original loop fell off the end and returned None for
    # sizes >= 1024 PiB; clamp everything that large to the 'P' unit.
    return f"{size:.1f}P"
def df(show_all=False, human_readable=False):
    """
    Get disk usage statistics, mirroring the ``df`` shell command.

    Args:
        show_all (bool): If True, include all filesystems. Default is False.
        human_readable (bool): If True, format sizes in human readable format. Default is False.

    Returns:
        list[dict]: one entry per partition with the keys 'filesystem',
        'total', 'used', 'free', 'percent' and 'mountpoint'.

    Usage:
        filesystems = df(show_all=True, human_readable=True)
        for filesystem in filesystems:
            print(f"Filesystem: {filesystem['filesystem']}")
            print(f"Total: {filesystem['total']}")
            print(f"Used: {filesystem['used']}")
            print(f"Free: {filesystem['free']}")
            print(f"Percent: {filesystem['percent']}%")
            print(f"Mountpoint: {filesystem['mountpoint']}")
    """
    report = []
    for part in psutil.disk_partitions(all=show_all):
        usage = psutil.disk_usage(part.mountpoint)
        report.append({
            'filesystem': part.device,
            'total': format_size(usage.total, human_readable),
            'used': format_size(usage.used, human_readable),
            'free': format_size(usage.free, human_readable),
            'percent': usage.percent,
            'mountpoint': part.mountpoint,
        })
    return report

228
utils.py
View File

@ -13,6 +13,10 @@ sys.path.append(os.getcwd())
# from modules.timer import Timer
import tarfile
import shutil
from pathlib import Path
import psutil
class ModelsRef:
def __init__(self):
self.models_ref = {}
@ -84,7 +88,8 @@ def upload_folder_to_s3(local_folder_path, bucket_name, s3_folder_path):
def upload_folder_to_s3_by_tar(local_folder_path, bucket_name, s3_folder_path):
tar_name = f"{os.path.basename(local_folder_path)}.tar"
os.system(f'tar cvf {tar_name} {local_folder_path}')
# os.system(f'tar cvf {tar_name} {local_folder_path}')
tar(mode='c', archive=tar_name, sfiles=local_folder_path, verbose=True)
# tar = tarfile.open(tar_path, "w:gz")
# for root, dirs, files in os.walk(local_folder_path):
# for file in files:
@ -93,7 +98,9 @@ def upload_folder_to_s3_by_tar(local_folder_path, bucket_name, s3_folder_path):
# tar.close()
s3_client = boto3.client('s3')
s3_client.upload_file(tar_name, bucket_name, os.path.join(s3_folder_path, tar_name))
os.system(f"rm {tar_name}")
# os.system(f"rm {tar_name}")
rm(tar_name, recursive=True)
def upload_file_to_s3(file_name, bucket, directory=None, object_name=None):
# If S3 object_name was not specified, use file_name
@ -149,7 +156,7 @@ def download_folder_from_s3(bucket_name, s3_folder_path, local_folder_path):
s3_resource = boto3.resource('s3')
bucket = s3_resource.Bucket(bucket_name)
for obj in bucket.objects.filter(Prefix=s3_folder_path):
obj_dirname = "/".join(os.path.dirname(obj.key).split("/")[1:])
obj_dirname = os.sep.join(os.path.dirname(obj.key).split("/")[1:])
obj_basename = os.path.basename(obj.key)
local_sub_folder_path = os.path.join(local_folder_path, obj_dirname)
if not os.path.exists(local_sub_folder_path):
@ -161,11 +168,13 @@ def download_folder_from_s3_by_tar(bucket_name, s3_tar_path, local_tar_path, tar
s3_client = boto3.client('s3')
s3_client.download_file(bucket_name, s3_tar_path, local_tar_path)
# tar_name = os.path.basename(s3_tar_path)
os.system(f"tar xvf {local_tar_path} -C {target_dir}")
# os.system(f"tar xvf {local_tar_path} -C {target_dir}")
tar(mode='x', archive=local_tar_path, verbose=True, change_dir=target_dir)
# tar = tarfile.open(local_tar_path, "r")
# tar.extractall()
# tar.close()
os.system(f"rm {local_tar_path}")
# os.system(f"rm {local_tar_path}")
rm(local_tar_path, recursive=True)
def download_file_from_s3(bucket_name, s3_file_path, local_file_path):
@ -235,6 +244,215 @@ def get_variable_from_json(variable_name, filename='sagemaker_ui.json'):
return variable_value
"""
Description: The functions below replace the existing shell-command implementation based on os.system, which is not OS-agnostic and not recommended.
"""
def tar(mode, archive, sfiles=None, verbose=False, change_dir=None):
    """
    Description:
        Create or extract a tar archive (OS-agnostic replacement for the `tar` shell command).
    Args:
        mode: 'c' for create or 'x' for extract
        archive: the archive file name
        sfiles: when creating, a list of file paths or a single directory-path string to add;
                when extracting, a list of member names to extract (None extracts all members)
        verbose: whether to print the names of the files as they are being processed
        change_dir: extraction target directory ('x' mode only; ignored in 'c' mode);
                None to use the current directory
    Usage:
        # Create a new archive
        tar(mode='c', archive='archive.tar', sfiles=['file1.txt', 'file2.txt'])
        # Extract files from an archive
        tar(mode='x', archive='archive.tar')
        # Create a new archive with verbose mode and input directory
        tar(mode='c', archive='archive.tar', sfiles='./some_directory', verbose=True)
        # Extract files from an archive with verbose mode and change directory
        tar(mode='x', archive='archive.tar', verbose=True, change_dir='./some_directory')
    """
    if mode == 'c':
        # NOTE: the local TarFile is named `tf`, not `tar`, so it does not
        # shadow this function's own name.
        with tarfile.open(archive, mode='w') as tf:
            # a list means explicit files; anything else is treated as a folder path
            if isinstance(sfiles, list):
                for file in sfiles:
                    if verbose:
                        print(f"Adding {file} to {archive}")
                    tf.add(file)
            else:
                for folder_path, subfolders, files in os.walk(sfiles):
                    for file in files:
                        if verbose:
                            print(f"Adding {os.path.join(folder_path, file)} to {archive}")
                        tf.add(os.path.join(folder_path, file))
    elif mode == 'x':
        with tarfile.open(archive, mode='r') as tf:
            # extract every member when no explicit list is given
            if not sfiles:
                sfiles = tf.getnames()
            for file in sfiles:
                if verbose:
                    print(f"Extracting {file} from {archive}")
                # extract to the requested directory, if any
                if change_dir:
                    tf.extract(file, path=change_dir)
                else:
                    tf.extract(file)
def rm(path, force=False, recursive=False):
    """
    Description:
        Remove a file, symlink or directory (OS-agnostic replacement for the `rm` shell command).
    Args:
        path (str): Path of the file or directory to remove.
        force (bool): If True, ignore non-existent files and errors. Default is False.
        recursive (bool): If True, remove directories and their contents recursively. Default is False.
    Raises:
        ValueError: if path is a directory and recursive is False, or path does not
            exist (both suppressed when force=True).
    Usage:
        # Remove the file
        rm(dst)
        # Remove a directory recursively
        rm("directory/path", recursive=True)
    """
    path_obj = Path(path)
    try:
        # Check is_symlink() first: a symlink to a directory must be unlinked as
        # a link, never followed -- shutil.rmtree refuses to operate on symlinks.
        if path_obj.is_symlink() or path_obj.is_file():
            path_obj.unlink()
        elif path_obj.is_dir():
            if recursive:
                shutil.rmtree(path)
            else:
                raise ValueError("Cannot remove directory without recursive=True")
        else:
            raise ValueError("File or directory does not exist")
    except Exception:
        if not force:
            raise
def cp(src, dst, recursive=False, dereference=False, preserve=True):
    """
    Description:
        Copy a file or directory from source path to destination path
        (OS-agnostic replacement for the `cp` shell command).
    Args:
        src (str): Source file or directory path.
        dst (str): Destination file or directory path.
        recursive (bool): If True, copy directory and its contents recursively. Default is False.
        dereference (bool): If True, always dereference symbolic links. Default is False.
        preserve (bool): If True, preserve file metadata (timestamps/permissions). Default is True.
    Raises:
        ValueError: if src is not a file, and not a directory copied with recursive=True.
    Usage:
        src = "source/file/path.txt"
        dst = "destination/file/path.txt"
        # Copy the file
        cp(src, dst)
        # Copy a directory recursively and dereference symlinks
        cp("source/directory", "destination/directory", recursive=True, dereference=True)
    """
    src_path = Path(src)
    dst_path = Path(dst)
    # Choose the per-file copy routine once; it is also passed to copytree so
    # that preserve=False takes effect for directory copies (copytree's default
    # copy_function is copy2, which always preserves metadata).
    file_copy = shutil.copy2 if preserve else shutil.copy
    try:
        if dereference:
            src_path = src_path.resolve()
        if src_path.is_dir() and recursive:
            shutil.copytree(src_path, dst_path, copy_function=file_copy, symlinks=not dereference)
        elif src_path.is_file():
            file_copy(src_path, dst_path)
        else:
            raise ValueError("Source must be a file or a directory with recursive=True")
    except shutil.SameFileError:
        # deliberate best-effort: mirror `cp`'s warning instead of crashing
        print("Source and destination represents the same file.")
def mv(src, dest, force=False):
    """
    Description:
        Move or rename a file or directory (OS-agnostic replacement for the `mv` shell command).
    Args:
        src (str): Source file or directory path.
        dest (str): Destination file or directory path.
        force (bool): If True, overwrite the destination if it exists. Default is False.
    Raises:
        FileExistsError: destination exists and force is False.
        FileNotFoundError: source does not exist.
    Usage:
        # Rename a file
        mv('old_name.txt', 'new_name.txt')
        # Move a file to a new directory
        mv('file.txt', 'new_directory/file.txt')
        # Move a directory to another directory
        mv('source_directory', 'destination_directory')
        # Force move (overwrite) a file or directory
        mv('source_file.txt', 'existing_destination_file.txt', force=True)
    """
    source = Path(src)
    destination = Path(dest)
    # guard: nothing to move
    if not source.exists():
        raise FileNotFoundError(f"Source path '{src}' does not exist")
    # guard: refuse to clobber unless explicitly forced
    if destination.exists() and not force:
        raise FileExistsError(f"Destination path '{dest}' already exists and 'force' is not set")
    # forced overwrite: clear whatever currently occupies the destination
    if destination.is_file():
        destination.unlink()
    elif destination.is_dir():
        shutil.rmtree(destination)
    if source.is_file() or source.is_dir():
        shutil.move(src, dest)
    else:
        # source exists but is neither a regular file nor a directory
        raise FileNotFoundError(f"Source path '{src}' does not exist")
def format_size(size, human_readable):
    """
    Description:
        Format a byte count for display (helper for df()).
    Args:
        size (int | float): size in bytes.
        human_readable (bool): if True, scale to the largest unit below 1024
            ('B' through 'P'); otherwise return the raw count as a string.
    Returns:
        str: the formatted size.
    """
    if not human_readable:
        return str(size)
    for unit in ['B', 'K', 'M', 'G', 'T', 'P']:
        if size < 1024:
            return f"{size:.1f}{unit}"
        size /= 1024
    # Size exceeded the unit table (>= 1024 PiB): report in exbibytes instead
    # of implicitly returning None as the previous implementation did.
    return f"{size:.1f}E"
def df(show_all=False, human_readable=False):
    """
    Description:
        Report disk usage statistics per mounted filesystem
        (OS-agnostic replacement for the `df` shell command).
    Args:
        show_all (bool): If True, include all filesystems. Default is False.
        human_readable (bool): If True, format sizes in human readable format. Default is False.
    Returns:
        list[dict]: one entry per partition with keys 'filesystem', 'total',
        'used', 'free', 'percent' and 'mountpoint'.
    Usage:
        filesystems = df(show_all=True, human_readable=True)
        for filesystem in filesystems:
            print(f"Filesystem: {filesystem['filesystem']}")
            print(f"Total: {filesystem['total']}")
            print(f"Used: {filesystem['used']}")
            print(f"Free: {filesystem['free']}")
            print(f"Percent: {filesystem['percent']}%")
            print(f"Mountpoint: {filesystem['mountpoint']}")
    """
    stats = []
    for part in psutil.disk_partitions(all=show_all):
        usage = psutil.disk_usage(part.mountpoint)
        stats.append({
            'filesystem': part.device,
            'total': format_size(usage.total, human_readable),
            'used': format_size(usage.used, human_readable),
            'free': format_size(usage.free, human_readable),
            'percent': usage.percent,
            'mountpoint': part.mountpoint,
        })
    return stats
if __name__ == '__main__':
import sys