Fix download filenames for models downloaded from CivitAi
parent
e713a1e71d
commit
1f8a804573
|
|
@ -7,6 +7,7 @@ from tqdm import tqdm
|
|||
from scripts.mo.dl.downloader import Downloader
|
||||
from scripts.mo.environment import env
|
||||
|
||||
|
||||
class HttpDownloader(Downloader):
|
||||
|
||||
def accepts_url(self, url: str) -> bool:
|
||||
|
|
@ -14,13 +15,19 @@ class HttpDownloader(Downloader):
|
|||
return parsed_url.scheme in ['http', 'https'] and parsed_url.hostname not in ['drive.google.com', 'mega.nz']
|
||||
|
||||
def fetch_filename(self, url):
|
||||
response = requests.get(url, headers={'Range': 'bytes=0-1'})
|
||||
|
||||
api_key = env.api_key()
|
||||
headers = {'Range': 'bytes=0-1'}
|
||||
if api_key:
|
||||
headers['Authorization'] = 'Bearer ' + api_key
|
||||
|
||||
response = requests.get(url, headers=headers)
|
||||
if response.status_code == 200 or response.status_code == 206:
|
||||
if 'Content-Disposition' in response.headers:
|
||||
content_disp = response.headers['Content-Disposition']
|
||||
filename = content_disp.split(';')[1].split('=')[1].strip('\"')
|
||||
return filename.encode('utf-8').decode('GBK').encode('utf-8').decode(
|
||||
'utf-8') # Needed to properly encode/decode chinese symbols, have fun.
|
||||
return (filename.encode('utf-8').decode('GBK').encode('utf-8')
|
||||
.decode('utf-8')) # Needed to properly encode/decode chinese symbols, have fun.
|
||||
else:
|
||||
return None
|
||||
|
||||
|
|
@ -30,14 +37,14 @@ class HttpDownloader(Downloader):
|
|||
|
||||
yield {'bytes_ready': 'None', 'bytes_total': 'None', 'speed_rate': 'None', 'elapsed': 'None'}
|
||||
|
||||
apiKey = env.api_key()
|
||||
if apiKey:
|
||||
authHeader = {'Content-Type':'application/json',
|
||||
'Authorization': 'Bearer ' + apiKey}
|
||||
response = requests.get(url, stream=True, headers=authHeader)
|
||||
api_key = env.api_key()
|
||||
if api_key:
|
||||
auth_header = {'Content-Type': 'application/json',
|
||||
'Authorization': 'Bearer ' + api_key}
|
||||
response = requests.get(url, stream=True, headers=auth_header)
|
||||
else:
|
||||
response = requests.get(url, stream=True)
|
||||
|
||||
|
||||
total_size = int(response.headers.get('content-length', 0))
|
||||
|
||||
yield {'bytes_ready': 0, 'bytes_total': total_size, 'speed_rate': 0, 'elapsed': 0}
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ import logging
|
|||
import os.path
|
||||
from typing import Callable
|
||||
|
||||
from scripts.mo.models import ModelType, Record
|
||||
from scripts.mo.data.storage import Storage
|
||||
from scripts.mo.models import ModelType
|
||||
|
||||
STORAGE_SQLITE = 'SQLite'
|
||||
STORAGE_FIREBASE = 'Firebase'
|
||||
|
|
@ -77,7 +77,7 @@ class Environment:
|
|||
card_height: Callable[[], str]
|
||||
is_debug_mode_enabled: Callable[[], bool]
|
||||
api_key: Callable[[], str]
|
||||
|
||||
|
||||
def is_storage_initialized(self) -> bool:
|
||||
return hasattr(self, 'storage')
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue