Merge pull request #82 from alexandersokol/fix/filenames-from-civitai

Fix download filenames for models downloaded from CivitAi
update
Alexander Sokol 2024-01-30 16:18:33 +02:00 committed by GitHub
commit f9b983536e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 11 deletions

View File

@ -7,6 +7,7 @@ from tqdm import tqdm
from scripts.mo.dl.downloader import Downloader
from scripts.mo.environment import env
class HttpDownloader(Downloader):
def accepts_url(self, url: str) -> bool:
@ -14,13 +15,19 @@ class HttpDownloader(Downloader):
return parsed_url.scheme in ['http', 'https'] and parsed_url.hostname not in ['drive.google.com', 'mega.nz']
def fetch_filename(self, url):
response = requests.get(url, headers={'Range': 'bytes=0-1'})
api_key = env.api_key()
headers = {'Range': 'bytes=0-1'}
if api_key:
headers['Authorization'] = 'Bearer ' + api_key
response = requests.get(url, headers=headers)
if response.status_code == 200 or response.status_code == 206:
if 'Content-Disposition' in response.headers:
content_disp = response.headers['Content-Disposition']
filename = content_disp.split(';')[1].split('=')[1].strip('\"')
return filename.encode('utf-8').decode('GBK').encode('utf-8').decode(
'utf-8') # Needed to properly encode/decode chinese symbols, have fun.
return (filename.encode('utf-8').decode('GBK').encode('utf-8')
.decode('utf-8')) # Needed to properly encode/decode chinese symbols, have fun.
else:
return None
@ -30,14 +37,14 @@ class HttpDownloader(Downloader):
yield {'bytes_ready': 'None', 'bytes_total': 'None', 'speed_rate': 'None', 'elapsed': 'None'}
apiKey = env.api_key()
if apiKey:
authHeader = {'Content-Type':'application/json',
'Authorization': 'Bearer ' + apiKey}
response = requests.get(url, stream=True, headers=authHeader)
api_key = env.api_key()
if api_key:
auth_header = {'Content-Type': 'application/json',
'Authorization': 'Bearer ' + api_key}
response = requests.get(url, stream=True, headers=auth_header)
else:
response = requests.get(url, stream=True)
total_size = int(response.headers.get('content-length', 0))
yield {'bytes_ready': 0, 'bytes_total': total_size, 'speed_rate': 0, 'elapsed': 0}

View File

@ -3,8 +3,8 @@ import logging
import os.path
from typing import Callable
from scripts.mo.models import ModelType, Record
from scripts.mo.data.storage import Storage
from scripts.mo.models import ModelType
STORAGE_SQLITE = 'SQLite'
STORAGE_FIREBASE = 'Firebase'
@ -77,7 +77,7 @@ class Environment:
card_height: Callable[[], str]
is_debug_mode_enabled: Callable[[], bool]
api_key: Callable[[], str]
def is_storage_initialized(self) -> bool:
return hasattr(self, 'storage')