diff --git a/aws_extension/sagemaker_ui.py b/aws_extension/sagemaker_ui.py index 60132e82..ae233695 100644 --- a/aws_extension/sagemaker_ui.py +++ b/aws_extension/sagemaker_ui.py @@ -23,6 +23,8 @@ from datetime import datetime import math import re +from utils import cp, tar, rm + inference_job_dropdown = None textual_inversion_dropdown = None hyperNetwork_dropdown = None @@ -196,7 +198,7 @@ def get_inference_job_list(txt2img_type_checkbox=True, img2img_type_checkbox=Tru except Exception as e: print("Exception occurred when fetching inference_job_ids") return gr.Dropdown.update(choices=[]) - + @@ -338,7 +340,7 @@ def refresh_all_models(): ckpt_type = ckpt["type"] checkpoint_info[ckpt_type] = {} for ckpt_name in ckpt["name"]: - ckpt_s3_pos = f"{ckpt['s3Location']}/{ckpt_name.split('/')[-1]}" + ckpt_s3_pos = f"{ckpt['s3Location']}/{ckpt_name.split(os.sep)[-1]}" checkpoint_info[ckpt_type][ckpt_name] = ckpt_s3_pos except Exception as e: print(f"Error refresh all models: {e}") @@ -355,7 +357,7 @@ def sagemaker_upload_model_s3(sd_checkpoints_path, textual_inversion_path, lora_ if lp == "" or not lp: continue print(f"lp is {lp}") - model_name = lp.split("/")[-1] + model_name = lp.split(os.sep)[-1] exist_model_list = list(checkpoint_info[rp].keys()) @@ -403,23 +405,28 @@ def sagemaker_upload_model_s3(sd_checkpoints_path, textual_inversion_path, lora_ s3_signed_urls_resp = json_response["s3PresignUrl"][local_tar_path] # Upload src model to S3. 
if rp != "embeddings" : - local_model_path_in_repo = f'models/{rp}/{model_name}' + local_model_path_in_repo = os.sep.join(['models', rp, model_name]) else: - local_model_path_in_repo = f'{rp}/{model_name}' + local_model_path_in_repo = os.sep.join([rp, model_name]) #local_tar_path = f'{model_name}.tar' print("Pack the model file.") - os.system(f"cp -f {lp} {local_model_path_in_repo}") + # os.system(f"cp -f {lp} {local_model_path_in_repo}") + cp(lp, local_model_path_in_repo, recursive=True) if rp == "Stable-diffusion": model_yaml_name = model_name.split('.')[0] + ".yaml" - local_model_yaml_path = "/".join(lp.split("/")[:-1]) + f"/{model_yaml_name}" - local_model_yaml_path_in_repo = f"models/{rp}/{model_yaml_name}" + local_model_yaml_path = os.sep.join([*lp.split(os.sep)[:-1], model_yaml_name]) + local_model_yaml_path_in_repo = os.sep.join(["models", rp, model_yaml_name]) if os.path.isfile(local_model_yaml_path): - os.system(f"cp -f {local_model_yaml_path} {local_model_yaml_path_in_repo}") - os.system(f"tar cvf {local_tar_path} {local_model_path_in_repo} {local_model_yaml_path_in_repo}") + # os.system(f"cp -f {local_model_yaml_path} {local_model_yaml_path_in_repo}") + # os.system(f"tar cvf {local_tar_path} {local_model_path_in_repo} {local_model_yaml_path_in_repo}") + cp(local_model_yaml_path, local_model_yaml_path_in_repo, recursive=True) + tar(mode='c', archive=local_tar_path, sfiles=[local_model_path_in_repo, local_model_yaml_path_in_repo], verbose=True) else: - os.system(f"tar cvf {local_tar_path} {local_model_path_in_repo}") + # os.system(f"tar cvf {local_tar_path} {local_model_path_in_repo}") + tar(mode='c', archive=local_tar_path, sfiles=[local_model_path_in_repo], verbose=True) else: - os.system(f"tar cvf {local_tar_path} {local_model_path_in_repo}") + # os.system(f"tar cvf {local_tar_path} {local_model_path_in_repo}") + tar(mode='c', archive=local_tar_path, sfiles=[local_model_path_in_repo], verbose=True) #upload_file_to_s3_by_presign_url(local_tar_path, 
s3_presigned_url) multiparts_tags = upload_multipart_files_to_s3_by_signed_url( local_tar_path, @@ -439,7 +446,8 @@ def sagemaker_upload_model_s3(sd_checkpoints_path, textual_inversion_path, lora_ log = f"\n finish upload {local_tar_path} to {s3_base}" - os.system(f"rm {local_tar_path}") + # os.system(f"rm {local_tar_path}") + rm(local_tar_path, recursive=True) except Exception as e: print(f"fail to upload model {lp}, error: {e}") @@ -753,7 +761,7 @@ def fake_gan(selected_value: str ): if inference_job_status == 'inprogress': return [], [], plaintext_to_html('inference still in progress') - if inference_job_taskType in ["txt2img", "img2img"]: + if inference_job_taskType in ["txt2img", "img2img"]: prompt_txt = '' images = get_inference_job_image_output(inference_job_id) image_list = [] @@ -786,7 +794,7 @@ def fake_gan(selected_value: str ): prompt_txt = caption image_list = [] # Return an empty list if selected_value is None json_list = [] - info_text = '' + info_text = '' infotexts = '' else: prompt_txt = '' @@ -854,9 +862,9 @@ def create_ui(is_img2img): with gr.Row(): global sagemaker_endpoint sagemaker_endpoint = gr.Dropdown(sagemaker_endpoints, - label="Select Cloud SageMaker Endpoint", - elem_id="sagemaker_endpoint_dropdown" - ) + label="Select Cloud SageMaker Endpoint", + elem_id="sagemaker_endpoint_dropdown" + ) modules.ui.create_refresh_button(sagemaker_endpoint, update_sagemaker_endpoints, lambda: {"choices": sagemaker_endpoints}, "refresh_sagemaker_endpoints") with gr.Row(): diff --git a/build_scripts/training/sagemaker_entrypoint.py b/build_scripts/training/sagemaker_entrypoint.py index ec707aa8..74a38264 100644 --- a/build_scripts/training/sagemaker_entrypoint.py +++ b/build_scripts/training/sagemaker_entrypoint.py @@ -22,6 +22,8 @@ os.environ['IGNORE_CMD_ARGS_ERRORS'] = "" from dreambooth.ui_functions import start_training from dreambooth.shared import status +from utls import tar, mv + def sync_status_from_s3_json(bucket_name, webui_status_file_path, 
sagemaker_status_file_path): while True: time.sleep(1) @@ -98,19 +100,16 @@ def upload_model_to_s3_v2(model_name, s3_output_path, model_type): yaml = os.path.join(root, f"{ckpt_name}.yaml") output_tar = file tar_command = f"tar cvf {output_tar} {safetensors} {yaml}" - print(tar_command) - os.system(tar_command) - logger.info(f"Upload check point to s3 {output_tar} {output_bucket_name} {s3_output_path}") - print(f"Upload check point to s3 {output_tar} {output_bucket_name} {s3_output_path}") - upload_file_to_s3(output_tar, output_bucket_name, os.path.join(s3_output_path, model_name)) elif model_type == "Lora": output_tar = file tar_command = f"tar cvf {output_tar} {safetensors}" - print(tar_command) - os.system(tar_command) - logger.info(f"Upload check point to s3 {output_tar} {output_bucket_name} {s3_output_path}") - print(f"Upload check point to s3 {output_tar} {output_bucket_name} {s3_output_path}") - upload_file_to_s3(output_tar, output_bucket_name, s3_output_path) + + print(tar_command) + # os.system(tar_command) + tar(mode='c', archive=output_tar, sfiles=[safetensors, yaml], verbose=True) + logger.info(f"Upload check point to s3 {output_tar} {output_bucket_name} {s3_output_path}") + print(f"Upload check point to s3 {output_tar} {output_bucket_name} {s3_output_path}") + upload_file_to_s3(output_tar, output_bucket_name, s3_output_path) def download_data(data_list, s3_data_path_list, s3_input_path): for data, data_tar in zip(data_list, s3_data_path_list): @@ -145,7 +144,8 @@ def prepare_for_training(s3_model_path, model_name, s3_input_path, data_tar_list download_db_config_path = f"models/sagemaker_dreambooth/{model_name}/db_config_cloud.json" target_db_config_path = f"models/dreambooth/{model_name}/db_config.json" logger.info(f"Move db_config to correct position {download_db_config_path} {target_db_config_path}") - os.system(f"mv {download_db_config_path} {target_db_config_path}") + # os.system(f"mv {download_db_config_path} {target_db_config_path}") + 
mv(download_db_config_path, target_db_config_path) with open(target_db_config_path) as db_config_file: db_config = json.load(db_config_file) data_list = [] @@ -171,7 +171,8 @@ def prepare_for_training_v2(s3_model_path, model_name, s3_input_path, s3_data_pa download_db_config_path = f"models/sagemaker_dreambooth/{model_name}/db_config_cloud.json" target_db_config_path = f"models/dreambooth/{model_name}/db_config.json" logger.info(f"Move db_config to correct position {download_db_config_path} {target_db_config_path}") - os.system(f"mv {download_db_config_path} {target_db_config_path}") + # os.system(f"mv {download_db_config_path} {target_db_config_path}") + mv(download_db_config_path, target_db_config_path) with open(target_db_config_path) as db_config_file: db_config = json.load(db_config_file) data_list = [] diff --git a/dreambooth_on_cloud/create_model.py b/dreambooth_on_cloud/create_model.py index 018898e0..2e9fbef5 100644 --- a/dreambooth_on_cloud/create_model.py +++ b/dreambooth_on_cloud/create_model.py @@ -10,8 +10,8 @@ import logging from modules import sd_models from utils import upload_multipart_files_to_s3_by_signed_url from utils import get_variable_from_json +from utils import tar import gradio as gr - logging.basicConfig(filename='sd-aws-ext.log', level=logging.ERROR, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) @@ -144,7 +144,8 @@ def async_create_model_on_sagemaker( multiparts_tags=[] if not from_hub: print("Pack the model file.") - os.system(f"tar cvf {local_tar_path} {local_model_path}") + # os.system(f"tar cvf {local_tar_path} {local_model_path}") + tar(mode='c', archive=local_tar_path, sfiles=[local_model_path], verbose=True) s3_base = json_response["job"]["s3_base"] print(f"Upload to S3 {s3_base}") print(f"Model ID: {model_id}") diff --git a/dreambooth_on_cloud/train.py b/dreambooth_on_cloud/train.py index 4368242f..a01b2db4 100644 --- a/dreambooth_on_cloud/train.py +++ b/dreambooth_on_cloud/train.py @@ 
-9,6 +9,7 @@ import logging import shutil from utils import upload_file_to_s3_by_presign_url from utils import get_variable_from_json +from utils import tar, cp logging.basicConfig(filename='sd-aws-ext.log', level=logging.ERROR, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) @@ -114,7 +115,8 @@ def async_prepare_for_training_on_sagemaker( url += "train" upload_files = [] db_config_tar = f"db_config.tar" - os.system(f"tar cvf {db_config_tar} {db_config_path}") + # os.system(f"tar cvf {db_config_tar} {db_config_path}") + tar(mode='c', archive=db_config_tar, sfiles=db_config_path, verbose=True) upload_files.append(db_config_tar) new_data_list = [] for data_path in data_path_list: @@ -125,7 +127,8 @@ def async_prepare_for_training_on_sagemaker( data_tar = f'data-{data_path.replace("/", "-").strip("-")}.tar' new_data_list.append(data_tar) print("Pack the data file.") - os.system(f"tar cf {data_tar} {data_path}") + # os.system(f"tar cf {data_tar} {data_path}") + tar(mode='c', archive=data_tar, sfiles=data_path, verbose=False) upload_files.append(data_tar) else: new_data_list.append(data_path) @@ -139,7 +142,8 @@ def async_prepare_for_training_on_sagemaker( new_class_data_list.append(class_data_tar) upload_files.append(class_data_tar) print("Pack the class data file.") - os.system(f"tar cf {class_data_tar} {class_data_path}") + # os.system(f"tar cf {class_data_tar} {class_data_path}") + tar(mode='c', archive=class_data_tar, sfiles=[class_data_path], verbose=False) else: new_class_data_list.append(class_data_path) payload = { diff --git a/pre-flight.bat b/pre-flight.bat new file mode 100644 index 00000000..a98d445d --- /dev/null +++ b/pre-flight.bat @@ -0,0 +1,112 @@ +@echo off +SETLOCAL ENABLEDELAYEDEXPANSION + +set INITIAL_SUPPORT_COMMIT_ROOT=89f9faa6 +set INITIAL_SUPPORT_COMMIT_CONTROLNET=7c674f83 +set INITIAL_SUPPORT_COMMIT_DREAMBOOTH=926ae204 + +set REPO_URL_LIST="https://github.com/Mikubill/sd-webui-controlnet.git 
https://github.com/d8ahazard/sd_dreambooth_extension.git" +set REPO_FOLDER_LIST="sd-webui-controlnet sd_dreambooth_extension" + +:show_help +echo Usage: %~nx0 -p/--pre-flight -s/--version-sync +goto :eof + +:get_supported_commit_list +set repo_url=%1 +set initial_support_commit=%2 +set latest_commit=%3 +for /f "tokens=*" %%i in ('git rev-list --topo-order %initial_support_commit%^..%latest_commit%') do echo %%i +goto :eof + +:get_latest_commit_id +set repo_url=%1 +for /f "tokens=1" %%i in ('git ls-remote "%repo_url%" HEAD ^| findstr /b "[0-9a-f]"') do set latest_commit_id=%%i +echo %latest_commit_id% +goto :eof + +:pre_flight_check +echo Start pre-flight check for WebUI... +call :get_latest_commit_id "https://github.com/AUTOMATIC1111/stable-diffusion-webui.git" +set LATEST_ROOT_COMMIT=%latest_commit_id% +echo Supported commits for WebUI: +call :get_supported_commit_list "https://github.com/AUTOMATIC1111/stable-diffusion-webui.git" "%INITIAL_SUPPORT_COMMIT_ROOT%" "%LATEST_ROOT_COMMIT%" +set SUPPORTED_ROOT_COMMITS=%supported_commit_list% +for /f "tokens=*" %%i in ('git -C ..\.. rev-parse HEAD') do set CUR_ROOT_COMMIT=%%i +echo Current commit id for WebUI: %CUR_ROOT_COMMIT% + +echo Pre-flight checks complete. + +goto :eof + +:version_sync +echo Start version sync for WebUI, make sure the extension folder is empty... + +set extension_folder=%1 +if not exist "%extension_folder%" ( + echo The extension folder does not exist: %extension_folder% + echo Please create it and run the script again. + goto :eof +) + +echo Syncing WebUI... + +for %%r in (%REPO_URL_LIST%) do ( + set repo_url=%%r + call :get_latest_commit_id !repo_url! + set latest_commit=!latest_commit_id! + + for %%f in (%REPO_FOLDER_LIST%) do ( + set repo_folder=%%f + if not exist "%extension_folder%\!repo_folder!" ( + echo Cloning !repo_url! into !repo_folder!... + git clone !repo_url! "%extension_folder%\!repo_folder!" + cd "%extension_folder%\!repo_folder!" + git checkout !latest_commit! 
+ cd %cd% + ) else ( + echo Updating !repo_folder! to the latest commit... + cd "%extension_folder%\!repo_folder!" + git fetch origin + git checkout !latest_commit! + cd %cd% + ) + ) +) + +echo Version sync complete. + +goto :eof + +:parse_options +set options=%* +if not "%options%" == "" ( + for %%o in (%options%) do ( + if "%%o" == "-p" ( + call :pre_flight_check + exit /b + ) else if "%%o" == "--pre-flight" ( + call :pre_flight_check + exit /b + ) else if "%%o" == "-s" ( + call :version_sync "extensions" + exit /b + ) else if "%%o" == "--version-sync" ( + call :version_sync "extensions" + exit /b + ) else if "%%o" == "-h" ( + call :show_help + exit /b + ) else if "%%o" == "--help" ( + call :show_help + exit /b + ) else ( + echo Unknown option: %%o + ) + ) +) else ( + call :show_help +) +goto :eof + +call :parse_options %* \ No newline at end of file diff --git a/scripts/requirements.txt b/scripts/requirements.txt index 1f15719b..e59ca81b 100644 --- a/scripts/requirements.txt +++ b/scripts/requirements.txt @@ -1,3 +1,4 @@ boto3>=1.26.28 requests -urllib \ No newline at end of file +urllib +psutil==5.9.5 diff --git a/test/test_windows.py b/test/test_windows.py new file mode 100644 index 00000000..dd71fbf8 --- /dev/null +++ b/test/test_windows.py @@ -0,0 +1,273 @@ +import os +import shutil +import filecmp +import unittest +import errno +import stat +import psutil +import sys + +# append the path to the utils module +# sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + +from windows import tar, rm, cp, mv, df + +class TestTarFunction(unittest.TestCase): + def setUp(self): + self.test_dir = 'test_tar_function' + try: + os.makedirs(self.test_dir) + except OSError as e: + if e.errno != errno.EEXIST: + raise + os.chdir(self.test_dir) + + with open('file1.txt', 'w') as f: + f.write('This is file1.') + with open('file2.txt', 'w') as f: + f.write('This is file2.') + + def tearDown(self): + os.chdir('..') + shutil.rmtree(self.test_dir) + + def 
test_tar_function(self): + # Test creating a new archive with file list as input + tar(mode='c', archive='archive.tar', sfiles=['file1.txt', 'file2.txt']) + + # Verify the archive was created + self.assertTrue(os.path.exists('archive.tar')) + + # Create a folder for extracted files + try: + os.makedirs('extracted') + except OSError as e: + if e.errno != errno.EEXIST: + raise + + # Test extracting files from the archive in test + # assemble the path to the archive + archive_path = os.path.join(os.getcwd(), 'archive.tar') + tar(mode='x', archive=archive_path, change_dir='extracted') + + # Verify the files were extracted + self.assertTrue(os.path.exists('extracted/file1.txt')) + self.assertTrue(os.path.exists('extracted/file2.txt')) + + # Compare the original and extracted files + self.assertTrue(filecmp.cmp('file1.txt', 'extracted/file1.txt')) + self.assertTrue(filecmp.cmp('file2.txt', 'extracted/file2.txt')) + + # Test creating a folder name as input + tar(mode='c', archive='archive2.tar', sfiles='extracted') + + # Verify the archive was created + self.assertTrue(os.path.exists('archive2.tar')) + + # Create a folder for extracted files + try: + os.makedirs('extracted2') + except OSError as e: + if e.errno != errno.EEXIST: + raise + + # Test extracting files from the archive in test + # assemble the path to the archive + archive_path = os.path.join(os.getcwd(), 'archive2.tar') + tar(mode='x', archive=archive_path, change_dir='extracted2') + + # Verify the files were extracted + self.assertTrue(os.path.exists('extracted2/extracted/file1.txt')) + self.assertTrue(os.path.exists('extracted2/extracted/file2.txt')) + + # Compare the original and extracted files + self.assertTrue(filecmp.cmp('file1.txt', 'extracted2/extracted/file1.txt')) + self.assertTrue(filecmp.cmp('file2.txt', 'extracted2/extracted/file2.txt')) + +class TestRmFunction(unittest.TestCase): + def setUp(self): + self.test_dir = 'test_rm_function' + os.makedirs(self.test_dir) + os.chdir(self.test_dir) + + 
with open('file1.txt', 'w') as f: + f.write('This is file1.') + + os.makedirs('dir1') + with open('dir1/file2.txt', 'w') as f: + f.write('This is file2 inside dir1.') + + def tearDown(self): + os.chdir('..') + shutil.rmtree(self.test_dir) + + def test_rm_file(self): + rm('file1.txt') + self.assertFalse(os.path.exists('file1.txt')) + + def test_rm_directory_without_recursive(self): + with self.assertRaises(ValueError): + rm('dir1') + + def test_rm_directory_with_recursive(self): + rm('dir1', recursive=True) + self.assertFalse(os.path.exists('dir1')) + + def test_rm_nonexistent_file_without_force(self): + with self.assertRaises(ValueError): + rm('nonexistent_file.txt') + + def test_rm_nonexistent_file_with_force(self): + try: + rm('nonexistent_file.txt', force=True) + except ValueError: + self.fail("rm() raised ValueError unexpectedly with force=True") + +class TestCpFunction(unittest.TestCase): + def setUp(self): + self.test_dir = 'test_cp_function' + os.makedirs(self.test_dir) + os.chdir(self.test_dir) + + with open('file1.txt', 'w') as f: + f.write('This is file1.') + + os.makedirs('dir1') + with open('dir1/file2.txt', 'w') as f: + f.write('This is file2 inside dir1.') + + def tearDown(self): + os.chdir('..') + shutil.rmtree(self.test_dir) + + def test_cp_file(self): + cp('file1.txt', 'file1_copy.txt') + self.assertTrue(os.path.exists('file1_copy.txt')) + + def test_cp_directory_without_recursive(self): + with self.assertRaises(ValueError): + cp('dir1', 'dir1_copy') + + def test_cp_directory_with_recursive(self): + cp('dir1', 'dir1_copy', recursive=True) + self.assertTrue(os.path.exists('dir1_copy')) + self.assertTrue(os.path.exists('dir1_copy/file2.txt')) + + def test_cp_dereference_symlink(self): + os.symlink('file1.txt', 'file1_symlink.txt') + cp('file1_symlink.txt', 'file1_dereferenced.txt', dereference=True) + self.assertTrue(os.path.exists('file1_dereferenced.txt')) + + def test_cp_preserve_file_metadata(self): + os.chmod('file1.txt', stat.S_IRUSR | 
stat.S_IWUSR | stat.S_IXUSR) + cp('file1.txt', 'file1_preserved.txt', preserve=True) + src_stat = os.stat('file1.txt') + dst_stat = os.stat('file1_preserved.txt') + self.assertEqual(src_stat.st_mode, dst_stat.st_mode) + + def test_cp_not_preserve_file_metadata(self): + os.chmod('file1.txt', stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR) + cp('file1.txt', 'file1_not_preserved.txt', preserve=False) + + # Modify the file mode of the destination file + os.chmod('file1_not_preserved.txt', stat.S_IRUSR | stat.S_IWUSR) + + src_stat = os.stat('file1.txt') + dst_stat = os.stat('file1_not_preserved.txt') + self.assertNotEqual(src_stat.st_mode, dst_stat.st_mode) + +class TestMvFunction(unittest.TestCase): + def setUp(self): + self.test_dir = 'test_mv_function' + os.makedirs(self.test_dir) + os.chdir(self.test_dir) + + with open('file1.txt', 'w') as f: + f.write('This is file1.') + + os.makedirs('dir1') + with open('dir1/file2.txt', 'w') as f: + f.write('This is file2 inside dir1.') + + def tearDown(self): + os.chdir('..') + shutil.rmtree(self.test_dir) + + def test_rename_file(self): + mv('file1.txt', 'file1_renamed.txt') + self.assertFalse(os.path.exists('file1.txt')) + self.assertTrue(os.path.exists('file1_renamed.txt')) + + def test_move_file_to_new_directory(self): + os.makedirs('new_directory') + mv('file1.txt', 'new_directory/file1.txt') + self.assertFalse(os.path.exists('file1.txt')) + self.assertTrue(os.path.exists('new_directory/file1.txt')) + + def test_move_directory(self): + os.makedirs('destination_directory') + mv('dir1', 'destination_directory/dir1') + self.assertFalse(os.path.exists('dir1')) + self.assertTrue(os.path.exists('destination_directory/dir1')) + + def test_force_move_overwrite_file(self): + with open('existing_destination_file.txt', 'w') as f: + f.write('This is the existing destination file.') + + mv('file1.txt', 'existing_destination_file.txt', force=True) + self.assertFalse(os.path.exists('file1.txt')) + 
self.assertTrue(os.path.exists('existing_destination_file.txt')) + + def test_force_move_overwrite_directory(self): + os.makedirs('existing_destination_directory') + with open('existing_destination_directory/file3.txt', 'w') as f: + f.write('This is file3 inside the existing destination directory.') + + mv('dir1', 'existing_destination_directory', force=True) + self.assertFalse(os.path.exists('dir1')) + self.assertTrue(os.path.exists('existing_destination_directory')) + self.assertTrue(os.path.exists('existing_destination_directory/file2.txt')) + self.assertFalse(os.path.exists('existing_destination_directory/file3.txt')) + + def test_move_nonexistent_file(self): + with self.assertRaises(FileNotFoundError): + mv('nonexistent_file.txt', 'some_destination.txt') + + def test_move_file_to_existing_destination_without_force(self): + with open('existing_destination_file.txt', 'w') as f: + f.write('This is the existing destination file.') + + with self.assertRaises(FileExistsError): + mv('file1.txt', 'existing_destination_file.txt') + +class TestDfFunction(unittest.TestCase): + def test_df_default_options(self): + filesystems = df() + for filesystem in filesystems: + self.assertIsInstance(filesystem['filesystem'], str) + self.assertIsInstance(filesystem['total'], str) + self.assertIsInstance(filesystem['used'], str) + self.assertIsInstance(filesystem['free'], str) + self.assertIsInstance(filesystem['percent'], float) + self.assertIsInstance(filesystem['mountpoint'], str) + + def test_df_show_all(self): + filesystems_all = df(show_all=True) + filesystems_default = df() + self.assertGreaterEqual(len(filesystems_all), len(filesystems_default)) + + def test_df_human_readable(self): + filesystems = df(human_readable=True) + for filesystem in filesystems: + self.assertIsInstance(filesystem['filesystem'], str) + self.assertIsInstance(filesystem['total'], str) + self.assertIsInstance(filesystem['used'], str) + self.assertIsInstance(filesystem['free'], str) + 
self.assertIsInstance(filesystem['percent'], float) + self.assertIsInstance(filesystem['mountpoint'], str) + self.assertTrue(filesystem['total'][-1] in ['B', 'K', 'M', 'G', 'T', 'P']) + self.assertTrue(filesystem['used'][-1] in ['B', 'K', 'M', 'G', 'T', 'P']) + self.assertTrue(filesystem['free'][-1] in ['B', 'K', 'M', 'G', 'T', 'P']) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/test/windows.py b/test/windows.py new file mode 100644 index 00000000..9581bab6 --- /dev/null +++ b/test/windows.py @@ -0,0 +1,208 @@ +import os +import tarfile +import shutil +from pathlib import Path +import psutil + +def tar(mode, archive, sfiles=None, verbose=False, change_dir=None): + """ + Description: + Create or extract a tar archive. + Args: + mode: 'c' for create or 'x' for extract + archive: the archive file name + files: a list of files to add to the archive (when creating) or extract (when extracting); None to extract all files + verbose: whether to print the names of the files as they are being processed + change_dir: the directory to change to before performing any other operations; None to use the current directory + Usage: + # Create a new archive + tar(mode='c', archive='archive.tar', sfiles=['file1.txt', 'file2.txt']) + + # Extract files from an archive + tar(mode='x', archive='archive.tar') + + # Create a new archive with verbose mode and input directory + tar(mode='c', archive='archive.tar', sfiles='./some_directory', verbose=True) + + # Extract files from an archive with verbose mode and change directory + tar(mode='x', archive='archive.tar', verbose=True, change_dir='./some_directory') + """ + if mode == 'c': + # os.chdir(change_dir) + with tarfile.open(archive, mode='w') as tar: + # check if input option file is a list or string + if isinstance(sfiles, list): + for file in sfiles: + if verbose: + print(f"Adding {file} to {archive}") + tar.add(file) + # take it as a folder name string + else: + for folder_path, subfolders, 
files in os.walk(sfiles): + for file in files: + if verbose: + print(f"Adding {os.path.join(folder_path, file)} to {archive}") + tar.add(os.path.join(folder_path, file)) + + elif mode == 'x': + with tarfile.open(archive, mode='r') as tar: + # sfiles is set to all files in the archive if not specified + if not sfiles: + sfiles = tar.getnames() + for file in sfiles: + if verbose: + print(f"Extracting {file} from {archive}") + # extra to specified directory + if change_dir: + tar.extract(file, path=change_dir) + else: + tar.extract(file) + +def rm(path, force=False, recursive=False): + """ + Description: + Remove a file or directory. + Args: + path (str): Path of the file or directory to remove. + force (bool): If True, ignore non-existent files and errors. Default is False. + recursive (bool): If True, remove directories and their contents recursively. Default is False. + Usage: + # Remove the file + rm(dst) + # Remove a directory recursively + rm("directory/path", recursive=True) + """ + path_obj = Path(path) + + try: + if path_obj.is_file() or (path_obj.is_symlink() and not path_obj.is_dir()): + path_obj.unlink() + elif path_obj.is_dir() and recursive: + shutil.rmtree(path) + elif path_obj.is_dir(): + raise ValueError("Cannot remove directory without recursive=True") + else: + raise ValueError("File or directory does not exist") + except Exception as e: + if not force: + raise e + +def cp(src, dst, recursive=False, dereference=False, preserve=True): + """ + Description: + Copy a file or directory from source path to destination path. + Args: + src (str): Source file or directory path. + dst (str): Destination file or directory path. + recursive (bool): If True, copy directory and its contents recursively. Default is False. + dereference (bool): If True, always dereference symbolic links. Default is False. + preserve (bool): If True, preserve file metadata. Default is True. 
+ Usage: + src = "source/file/path.txt" + dst = "destination/file/path.txt" + + # Copy the file + cp(src, dst) + + # Copy a directory recursively and dereference symlinks + cp("source/directory", "destination/directory", recursive=True, dereference=True) + """ + src_path = Path(src) + dst_path = Path(dst) + + if dereference: + src_path = src_path.resolve() + + if src_path.is_dir() and recursive: + if preserve: + shutil.copytree(src_path, dst_path, copy_function=shutil.copy2, symlinks=not dereference) + else: + shutil.copytree(src_path, dst_path, symlinks=not dereference) + elif src_path.is_file(): + if preserve: + shutil.copy2(src_path, dst_path) + else: + shutil.copy(src_path, dst_path) + else: + raise ValueError("Source must be a file or a directory with recursive=True") + +def mv(src, dest, force=False): + """ + Description: + Move or rename files and directories. + Args: + src (str): Source file or directory path. + dest (str): Destination file or directory path. + force (bool): If True, overwrite the destination if it exists. Default is False. 
+ Usage: + # Rename a file + mv('old_name.txt', 'new_name.txt') + + # Move a file to a new directory + mv('file.txt', 'new_directory/file.txt') + + # Move a directory to another directory + mv('source_directory', 'destination_directory') + + # Force move (overwrite) a file or directory + mv('source_file.txt', 'existing_destination_file.txt', force=True) + """ + src_path = Path(src) + dest_path = Path(dest) + + if src_path.exists(): + if dest_path.exists() and not force: + raise FileExistsError(f"Destination path '{dest}' already exists and 'force' is not set") + else: + if dest_path.is_file(): + dest_path.unlink() + elif dest_path.is_dir(): + shutil.rmtree(dest_path) + + if src_path.is_file() or src_path.is_dir(): + shutil.move(src, dest) + else: + raise FileNotFoundError(f"Source path '{src}' does not exist") + +def format_size(size, human_readable): + if human_readable: + for unit in ['B', 'K', 'M', 'G', 'T', 'P']: + if size < 1024: + return f"{size:.1f}{unit}" + size /= 1024 + else: + return str(size) + +def df(show_all=False, human_readable=False): + """ + Description: + Get disk usage statistics. + Args: + show_all (bool): If True, include all filesystems. Default is False. + human_readable (bool): If True, format sizes in human readable format. Default is False. 
+ Usage: + filesystems = df(show_all=True, human_readable=True) + for filesystem in filesystems: + print(f"Filesystem: {filesystem['filesystem']}") + print(f"Total: {filesystem['total']}") + print(f"Used: {filesystem['used']}") + print(f"Free: {filesystem['free']}") + print(f"Percent: {filesystem['percent']}%") + print(f"Mountpoint: {filesystem['mountpoint']}") + """ + partitions = psutil.disk_partitions(all=show_all) + result = [] + + for partition in partitions: + usage = psutil.disk_usage(partition.mountpoint) + partition_info = { + 'filesystem': partition.device, + 'total': format_size(usage.total, human_readable), + 'used': format_size(usage.used, human_readable), + 'free': format_size(usage.free, human_readable), + 'percent': usage.percent, + 'mountpoint': partition.mountpoint, + } + result.append(partition_info) + + return result diff --git a/utils.py b/utils.py index 6bd15ebb..b55aa77b 100644 --- a/utils.py +++ b/utils.py @@ -13,6 +13,10 @@ sys.path.append(os.getcwd()) # from modules.timer import Timer import tarfile +import shutil +from pathlib import Path +import psutil + class ModelsRef: def __init__(self): self.models_ref = {} @@ -84,7 +88,8 @@ def upload_folder_to_s3(local_folder_path, bucket_name, s3_folder_path): def upload_folder_to_s3_by_tar(local_folder_path, bucket_name, s3_folder_path): tar_name = f"{os.path.basename(local_folder_path)}.tar" - os.system(f'tar cvf {tar_name} {local_folder_path}') + # os.system(f'tar cvf {tar_name} {local_folder_path}') + tar(mode='c', archive=tar_name, sfiles=local_folder_path, verbose=True) # tar = tarfile.open(tar_path, "w:gz") # for root, dirs, files in os.walk(local_folder_path): # for file in files: @@ -93,7 +98,9 @@ def upload_folder_to_s3_by_tar(local_folder_path, bucket_name, s3_folder_path): # tar.close() s3_client = boto3.client('s3') s3_client.upload_file(tar_name, bucket_name, os.path.join(s3_folder_path, tar_name)) - os.system(f"rm {tar_name}") + # os.system(f"rm {tar_name}") + rm(tar_name, 
"""
Description: OS-agnostic replacements for the shell commands (cp, tar, rm,
mv, df) that were previously invoked through ``os.system``, reimplemented
with the Python standard library so they behave the same on every platform.
"""


def tar(mode, archive, sfiles=None, verbose=False, change_dir=None):
    """
    Create or extract a tar archive.

    Args:
        mode: 'c' to create an archive, 'x' to extract one.
        archive: path of the archive file.
        sfiles: when creating, a list of file paths or a single path string
            (file or directory) to add; when extracting, the member names to
            extract, or None to extract every member.
        verbose: print each file name as it is processed.
        change_dir: directory to extract into (None = current directory).
            Ignored when creating.

    Usage:
        # Create a new archive from explicit files
        tar(mode='c', archive='archive.tar', sfiles=['file1.txt', 'file2.txt'])

        # Create a new archive from a directory tree, verbosely
        tar(mode='c', archive='archive.tar', sfiles='./some_directory', verbose=True)

        # Extract all files into a target directory, verbosely
        tar(mode='x', archive='archive.tar', verbose=True, change_dir='./some_directory')
    """
    if mode == 'c':
        with tarfile.open(archive, mode='w') as archive_obj:
            if isinstance(sfiles, list):
                for member in sfiles:
                    if verbose:
                        print(f"Adding {member} to {archive}")
                    archive_obj.add(member)
            elif sfiles is not None and os.path.isfile(sfiles):
                # Fix: a plain-file string previously produced an EMPTY
                # archive, because os.walk() yields nothing for a file.
                if verbose:
                    print(f"Adding {sfiles} to {archive}")
                archive_obj.add(sfiles)
            else:
                # Treat the string as a directory and add its files recursively.
                for folder_path, _subfolders, files in os.walk(sfiles):
                    for name in files:
                        member = os.path.join(folder_path, name)
                        if verbose:
                            print(f"Adding {member} to {archive}")
                        archive_obj.add(member)
    elif mode == 'x':
        with tarfile.open(archive, mode='r') as archive_obj:
            # Default to every member when no explicit list is given.
            if not sfiles:
                sfiles = archive_obj.getnames()
            # NOTE(review): member paths are not sanitised; only extract
            # archives produced by this module or another trusted source.
            for member in sfiles:
                if verbose:
                    print(f"Extracting {member} from {archive}")
                if change_dir:
                    archive_obj.extract(member, path=change_dir)
                else:
                    archive_obj.extract(member)


def rm(path, force=False, recursive=False):
    """
    Remove a file or directory (replacement for ``rm`` / ``rm -rf``).

    Args:
        path (str): path of the file or directory to remove.
        force (bool): if True, ignore missing paths and removal errors.
        recursive (bool): if True, allow removing a directory and its contents.

    Raises:
        ValueError: the path is a directory but recursive=False, or the path
            does not exist (both suppressed when force=True).

    Usage:
        rm("some/file.txt")
        rm("directory/path", recursive=True)
    """
    path_obj = Path(path)
    try:
        if path_obj.is_symlink():
            # Fix: remove the link itself. shutil.rmtree refuses symlinks,
            # so a symlink to a directory previously raised instead of
            # being removed like `rm` would.
            path_obj.unlink()
        elif path_obj.is_file():
            path_obj.unlink()
        elif path_obj.is_dir() and recursive:
            shutil.rmtree(path)
        elif path_obj.is_dir():
            raise ValueError("Cannot remove directory without recursive=True")
        else:
            raise ValueError("File or directory does not exist")
    except Exception as e:
        if not force:
            raise e


def cp(src, dst, recursive=False, dereference=False, preserve=True):
    """
    Copy a file or directory (replacement for ``cp`` / ``cp -r``).

    Args:
        src (str): source file or directory path.
        dst (str): destination path.
        recursive (bool): if True, copy a directory and its contents.
        dereference (bool): if True, always follow symbolic links.
        preserve (bool): if True, preserve file metadata (like ``cp -p``).

    Usage:
        cp("source/file/path.txt", "destination/file/path.txt")
        cp("source/directory", "destination/directory", recursive=True, dereference=True)
    """
    src_path = Path(src)
    dst_path = Path(dst)
    try:
        if dereference:
            src_path = src_path.resolve()

        # Fix: copytree's default copy_function is already copy2, so
        # preserve=False previously still preserved metadata; pass the
        # metadata-less shutil.copy explicitly in that case.
        copy_fn = shutil.copy2 if preserve else shutil.copy

        if src_path.is_dir() and recursive:
            # NOTE(review): copytree raises if dst already exists, unlike
            # `cp -rf`; the call sites in this repo only copy regular files.
            shutil.copytree(src_path, dst_path, copy_function=copy_fn,
                            symlinks=not dereference)
        elif src_path.is_file():
            copy_fn(src_path, dst_path)
        else:
            raise ValueError("Source must be a file or a directory with recursive=True")
    except shutil.SameFileError:
        # Grammar fix in the diagnostic message ("represents" -> "represent").
        print("Source and destination represent the same file.")


def mv(src, dest, force=False):
    """
    Move or rename a file or directory (replacement for ``mv`` / ``mv -f``).

    Args:
        src (str): source file or directory path.
        dest (str): destination path.
        force (bool): if True, overwrite an existing destination.

    Raises:
        FileNotFoundError: the source path does not exist.
        FileExistsError: the destination exists and force=False.

    Usage:
        mv('old_name.txt', 'new_name.txt')
        mv('file.txt', 'new_directory/file.txt')
        mv('source_directory', 'destination_directory')
        mv('source_file.txt', 'existing_destination_file.txt', force=True)
    """
    src_path = Path(src)
    dest_path = Path(dest)

    if not src_path.exists():
        raise FileNotFoundError(f"Source path '{src}' does not exist")

    if dest_path.exists():
        if not force:
            raise FileExistsError(f"Destination path '{dest}' already exists and 'force' is not set")
        # force=True: clear the destination first, like `mv -f`.
        if dest_path.is_dir():
            shutil.rmtree(dest_path)
        else:
            dest_path.unlink()

    shutil.move(src, dest)


def format_size(size, human_readable):
    """
    Format a byte count, optionally as a human-readable string (e.g. '1.5K').

    Args:
        size (int): size in bytes.
        human_readable (bool): if True, scale to B/K/M/G/T/P/E units;
            otherwise return the raw number as a string.
    """
    if not human_readable:
        return str(size)
    for unit in ['B', 'K', 'M', 'G', 'T', 'P']:
        if size < 1024:
            return f"{size:.1f}{unit}"
        size /= 1024
    # Fix: sizes >= 1024 PB previously fell off the loop and returned None.
    return f"{size:.1f}E"


def df(show_all=False, human_readable=False):
    """
    Get disk usage statistics (replacement for ``df``).

    Args:
        show_all (bool): if True, include all filesystems.
        human_readable (bool): if True, format sizes in human-readable units.

    Returns:
        list[dict]: one dict per partition with keys 'filesystem', 'total',
        'used', 'free', 'percent', and 'mountpoint'.

    Usage:
        filesystems = df(show_all=True, human_readable=True)
        for filesystem in filesystems:
            print(f"Filesystem: {filesystem['filesystem']}")
            print(f"Total: {filesystem['total']}")
            print(f"Used: {filesystem['used']}")
            print(f"Free: {filesystem['free']}")
            print(f"Percent: {filesystem['percent']}%")
            print(f"Mountpoint: {filesystem['mountpoint']}")
    """
    partitions = psutil.disk_partitions(all=show_all)
    result = []

    for partition in partitions:
        usage = psutil.disk_usage(partition.mountpoint)
        result.append({
            'filesystem': partition.device,
            'total': format_size(usage.total, human_readable),
            'used': format_size(usage.used, human_readable),
            'free': format_size(usage.free, human_readable),
            'percent': usage.percent,
            'mountpoint': partition.mountpoint,
        })

    return result


if __name__ == '__main__':
    import sys