1
0
Fork 0

Merging upstream version 0.9.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-05 13:50:07 +01:00
parent 857951d9ac
commit 161de8690e
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
28 changed files with 1073 additions and 859 deletions

View file

@ -8,8 +8,13 @@
eos_downloader class definition
"""
from __future__ import (absolute_import, division, print_function,
unicode_literals, annotations)
from __future__ import (
absolute_import,
annotations,
division,
print_function,
unicode_literals,
)
import base64
import glob
@ -26,9 +31,14 @@ from loguru import logger
from rich import console
from tqdm import tqdm
from eos_downloader import (ARISTA_DOWNLOAD_URL, ARISTA_GET_SESSION,
ARISTA_SOFTWARE_FOLDER_TREE, EVE_QEMU_FOLDER_PATH,
MSG_INVALID_DATA, MSG_TOKEN_EXPIRED)
from eos_downloader import (
ARISTA_DOWNLOAD_URL,
ARISTA_GET_SESSION,
ARISTA_SOFTWARE_FOLDER_TREE,
EVE_QEMU_FOLDER_PATH,
MSG_INVALID_DATA,
MSG_TOKEN_EXPIRED,
)
from eos_downloader.data import DATA_MAPPING
from eos_downloader.download import DownloadProgressBar
@ -37,11 +47,19 @@ from eos_downloader.download import DownloadProgressBar
console = rich.get_console()
class ObjectDownloader():
class ObjectDownloader:
"""
ObjectDownloader Generic Object to download from Arista.com
"""
def __init__(self, image: str, version: str, token: str, software: str = 'EOS', hash_method: str = 'md5sum'):
def __init__(
self,
image: str,
version: str,
token: str,
software: str = "EOS",
hash_method: str = "md5sum",
):
"""
__init__ Class constructor
@ -70,10 +88,10 @@ class ObjectDownloader():
self.hash_method = hash_method
self.timeout = 5
# Logging
logger.debug(f'Filename built by _build_filename is {self.filename}')
logger.debug(f"Filename built by _build_filename is {self.filename}")
def __str__(self) -> str:
return f'{self.software} - {self.image} - {self.version}'
return f"{self.software} - {self.image} - {self.version}"
# def __repr__(self):
# return str(self.__dict__)
@ -102,16 +120,18 @@ class ObjectDownloader():
str:
Filename to search for on Arista.com
"""
logger.info('start build')
logger.info("start build")
if self.software in DATA_MAPPING:
logger.info(f'software in data mapping: {self.software}')
logger.info(f"software in data mapping: {self.software}")
if self.image in DATA_MAPPING[self.software]:
logger.info(f'image in data mapping: {self.image}')
logger.info(f"image in data mapping: {self.image}")
return f"{DATA_MAPPING[self.software][self.image]['prepend']}-{self.version}{DATA_MAPPING[self.software][self.image]['extension']}"
return f"{DATA_MAPPING[self.software]['default']['prepend']}-{self.version}{DATA_MAPPING[self.software]['default']['extension']}"
raise ValueError(f'Incorrect value for software {self.software}')
raise ValueError(f"Incorrect value for software {self.software}")
def _parse_xml_for_path(self, root_xml: ET.ElementTree, xpath: str, search_file: str) -> str:
def _parse_xml_for_path(
self, root_xml: ET.ElementTree, xpath: str, search_file: str
) -> str:
# sourcery skip: remove-unnecessary-cast
"""
_parse_xml Read and extract data from XML using XPATH
@ -132,18 +152,18 @@ class ObjectDownloader():
str
File Path on Arista server side
"""
logger.debug(f'Using xpath {xpath}')
logger.debug(f'Search for file {search_file}')
console.print(f'🔎 Searching file {search_file}')
logger.debug(f"Using xpath {xpath}")
logger.debug(f"Search for file {search_file}")
console.print(f"🔎 Searching file {search_file}")
for node in root_xml.findall(xpath):
# logger.debug('Found {}', node.text)
if str(node.text).lower() == search_file.lower():
path = node.get('path')
console.print(f' -> Found file at {path}')
path = node.get("path")
console.print(f" -> Found file at {path}")
logger.info(f'Found {node.text} at {node.get("path")}')
return str(node.get('path')) if node.get('path') is not None else ''
logger.error(f'Requested file ({self.filename}) not found !')
return ''
return str(node.get("path")) if node.get("path") is not None else ""
logger.error(f"Requested file ({self.filename}) not found !")
return ""
def _get_hash(self, file_path: str) -> str:
"""
@ -165,10 +185,10 @@ class ObjectDownloader():
dl_rich_progress_bar = DownloadProgressBar()
dl_rich_progress_bar.download(urls=[hash_url], dest_dir=file_path)
hash_downloaded = f"{file_path}/{os.path.basename(remote_hash_file)}"
hash_content = 'unset'
with open(hash_downloaded, 'r', encoding='utf-8') as f:
hash_content = "unset"
with open(hash_downloaded, "r", encoding="utf-8") as f:
hash_content = f.read()
return hash_content.split(' ')[0]
return hash_content.split(" ")[0]
@staticmethod
def _compute_hash_md5sum(file: str, hash_expected: str) -> bool:
@ -195,7 +215,9 @@ class ObjectDownloader():
hash_md5.update(chunk)
if hash_md5.hexdigest() == hash_expected:
return True
logger.warning(f'Downloaded file is corrupt: local md5 ({hash_md5.hexdigest()}) is different to md5 from arista ({hash_expected})')
logger.warning(
f"Downloaded file is corrupt: local md5 ({hash_md5.hexdigest()}) is different to md5 from arista ({hash_expected})"
)
return False
@staticmethod
@ -223,10 +245,12 @@ class ObjectDownloader():
hash_sha512.update(chunk)
if hash_sha512.hexdigest() == hash_expected:
return True
logger.warning(f'Downloaded file is corrupt: local sha512 ({hash_sha512.hexdigest()}) is different to sha512 from arista ({hash_expected})')
logger.warning(
f"Downloaded file is corrupt: local sha512 ({hash_sha512.hexdigest()}) is different to sha512 from arista ({hash_expected})"
)
return False
def _get_folder_tree(self) -> ET.ElementTree:
def get_folder_tree(self) -> ET.ElementTree:
"""
_get_folder_tree Download XML tree from Arista server
@ -237,15 +261,17 @@ class ObjectDownloader():
"""
if self.session_id is None:
self.authenticate()
jsonpost = {'sessionCode': self.session_id}
result = requests.post(ARISTA_SOFTWARE_FOLDER_TREE, data=json.dumps(jsonpost), timeout=self.timeout)
jsonpost = {"sessionCode": self.session_id}
result = requests.post(
ARISTA_SOFTWARE_FOLDER_TREE, data=json.dumps(jsonpost), timeout=self.timeout
)
try:
folder_tree = result.json()["data"]["xml"]
return ET.ElementTree(ET.fromstring(folder_tree))
except KeyError as error:
logger.error(MSG_INVALID_DATA)
logger.error(f'Server returned: {error}')
console.print(f'{MSG_INVALID_DATA}', style="bold red")
logger.error(f"Server returned: {error}")
console.print(f"{MSG_INVALID_DATA}", style="bold red")
sys.exit(1)
def _get_remote_filepath(self) -> str:
@ -259,12 +285,14 @@ class ObjectDownloader():
str
Remote path of the file to download
"""
root = self._get_folder_tree()
root = self.get_folder_tree()
logger.debug("GET XML content from ARISTA.com")
xpath = f'.//dir[@label="{self.software}"]//file'
return self._parse_xml_for_path(root_xml=root, xpath=xpath, search_file=self.filename)
return self._parse_xml_for_path(
root_xml=root, xpath=xpath, search_file=self.filename
)
def _get_remote_hashpath(self, hash_method: str = 'md5sum') -> str:
def _get_remote_hashpath(self, hash_method: str = "md5sum") -> str:
"""
_get_remote_hashpath Helper to get path of the hash's file to download
@ -275,16 +303,16 @@ class ObjectDownloader():
str
Remote path of the hash's file to download
"""
root = self._get_folder_tree()
root = self.get_folder_tree()
logger.debug("GET XML content from ARISTA.com")
xpath = f'.//dir[@label="{self.software}"]//file'
return self._parse_xml_for_path(
root_xml=root,
xpath=xpath,
search_file=f'{self.filename}.{hash_method}',
search_file=f"{self.filename}.{hash_method}",
)
def _get_url(self, remote_file_path: str) -> str:
def _get_url(self, remote_file_path: str) -> str:
"""
_get_url Get URL to use for downloading file from Arista server
@ -302,13 +330,15 @@ class ObjectDownloader():
"""
if self.session_id is None:
self.authenticate()
jsonpost = {'sessionCode': self.session_id, 'filePath': remote_file_path}
result = requests.post(ARISTA_DOWNLOAD_URL, data=json.dumps(jsonpost), timeout=self.timeout)
if 'data' in result.json() and 'url' in result.json()['data']:
jsonpost = {"sessionCode": self.session_id, "filePath": remote_file_path}
result = requests.post(
ARISTA_DOWNLOAD_URL, data=json.dumps(jsonpost), timeout=self.timeout
)
if "data" in result.json() and "url" in result.json()["data"]:
# logger.debug('URL to download file is: {}', result.json())
return result.json()["data"]["url"]
logger.critical(f'Server returns following message: {result.json()}')
return ''
logger.critical(f"Server returns following message: {result.json()}")
return ""
@staticmethod
def _download_file_raw(url: str, file_path: str) -> str:
@ -331,31 +361,40 @@ class ObjectDownloader():
"""
chunkSize = 1024
r = requests.get(url, stream=True, timeout=5)
with open(file_path, 'wb') as f:
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']), unit_scale=True, unit_divisor=1024)
with open(file_path, "wb") as f:
pbar = tqdm(
unit="B",
total=int(r.headers["Content-Length"]),
unit_scale=True,
unit_divisor=1024,
)
for chunk in r.iter_content(chunk_size=chunkSize):
if chunk:
pbar.update(len(chunk))
f.write(chunk)
return file_path
def _download_file(self, file_path: str, filename: str, rich_interface: bool = True) -> Union[None, str]:
def _download_file(
self, file_path: str, filename: str, rich_interface: bool = True
) -> Union[None, str]:
remote_file_path = self._get_remote_filepath()
logger.info(f'File found on arista server: {remote_file_path}')
logger.info(f"File found on arista server: {remote_file_path}")
file_url = self._get_url(remote_file_path=remote_file_path)
if file_url is not False:
if not rich_interface:
return self._download_file_raw(url=file_url, file_path=os.path.join(file_path, filename))
return self._download_file_raw(
url=file_url, file_path=os.path.join(file_path, filename)
)
rich_downloader = DownloadProgressBar()
rich_downloader.download(urls=[file_url], dest_dir=file_path)
return os.path.join(file_path, filename)
logger.error(f'Cannot download file {file_path}')
logger.error(f"Cannot download file {file_path}")
return None
@staticmethod
def _create_destination_folder(path: str) -> None:
# os.makedirs(path, mode, exist_ok=True)
os.system(f'mkdir -p {path}')
os.system(f"mkdir -p {path}")
@staticmethod
def _disable_ztp(file_path: str) -> None:
@ -379,24 +418,29 @@ class ObjectDownloader():
"""
credentials = (base64.b64encode(self.token.encode())).decode("utf-8")
session_code_url = ARISTA_GET_SESSION
jsonpost = {'accessToken': credentials}
jsonpost = {"accessToken": credentials}
result = requests.post(session_code_url, data=json.dumps(jsonpost), timeout=self.timeout)
result = requests.post(
session_code_url, data=json.dumps(jsonpost), timeout=self.timeout
)
if result.json()["status"]["message"] in[ 'Access token expired', 'Invalid access token']:
console.print(f'{MSG_TOKEN_EXPIRED}', style="bold red")
if result.json()["status"]["message"] in [
"Access token expired",
"Invalid access token",
]:
console.print(f"{MSG_TOKEN_EXPIRED}", style="bold red")
logger.error(MSG_TOKEN_EXPIRED)
return False
try:
if 'data' in result.json():
if "data" in result.json():
self.session_id = result.json()["data"]["session_code"]
logger.info('Authenticated on arista.com')
logger.info("Authenticated on arista.com")
return True
logger.debug(f'{result.json()}')
logger.debug(f"{result.json()}")
return False
except KeyError as error_arista:
logger.error(f'Error: {error_arista}')
logger.error(f"Error: {error_arista}")
sys.exit(1)
def download_local(self, file_path: str, checksum: bool = False) -> bool:
@ -422,25 +466,33 @@ class ObjectDownloader():
bool
True if everything went well, False if any problem appears
"""
file_downloaded = str(self._download_file(file_path=file_path, filename=self.filename))
file_downloaded = str(
self._download_file(file_path=file_path, filename=self.filename)
)
# Check file HASH
hash_result = False
if checksum:
logger.info('🚀 Running checksum validation')
console.print('🚀 Running checksum validation')
if self.hash_method == 'md5sum':
logger.info("🚀 Running checksum validation")
console.print("🚀 Running checksum validation")
if self.hash_method == "md5sum":
hash_expected = self._get_hash(file_path=file_path)
hash_result = self._compute_hash_md5sum(file=file_downloaded, hash_expected=hash_expected)
elif self.hash_method == 'sha512sum':
hash_result = self._compute_hash_md5sum(
file=file_downloaded, hash_expected=hash_expected
)
elif self.hash_method == "sha512sum":
hash_expected = self._get_hash(file_path=file_path)
hash_result = self._compute_hash_sh512sum(file=file_downloaded, hash_expected=hash_expected)
hash_result = self._compute_hash_sh512sum(
file=file_downloaded, hash_expected=hash_expected
)
if not hash_result:
logger.error('Downloaded file is corrupted, please check your connection')
console.print('❌ Downloaded file is corrupted, please check your connection')
logger.error("Downloaded file is corrupted, please check your connection")
console.print(
"❌ Downloaded file is corrupted, please check your connection"
)
return False
logger.info('Downloaded file is correct.')
console.print('✅ Downloaded file is correct.')
logger.info("Downloaded file is correct.")
console.print("✅ Downloaded file is correct.")
return True
def provision_eve(self, noztp: bool = False, checksum: bool = True) -> None:
@ -466,7 +518,7 @@ class ObjectDownloader():
# Build image name to use in folder path
eos_image_name = self.filename.rstrip(".vmdk").lower()
if noztp:
eos_image_name = f'{eos_image_name}-noztp'
eos_image_name = f"{eos_image_name}-noztp"
# Create full path for EVE-NG
file_path = os.path.join(EVE_QEMU_FOLDER_PATH, eos_image_name.rstrip())
# Create folders in filesystem
@ -474,20 +526,23 @@ class ObjectDownloader():
# Download file to local destination
file_downloaded = self._download_file(
file_path=file_path, filename=self.filename)
file_path=file_path, filename=self.filename
)
# Convert to QCOW2 format
file_qcow2 = os.path.join(file_path, "hda.qcow2")
logger.info('Converting VMDK to QCOW2 format')
console.print('🚀 Converting VMDK to QCOW2 format...')
logger.info("Converting VMDK to QCOW2 format")
console.print("🚀 Converting VMDK to QCOW2 format...")
os.system(f'$(which qemu-img) convert -f vmdk -O qcow2 {file_downloaded} {file_qcow2}')
os.system(
f"$(which qemu-img) convert -f vmdk -O qcow2 {file_downloaded} {file_qcow2}"
)
logger.info('Applying unl_wrapper to fix permissions')
console.print('Applying unl_wrapper to fix permissions')
logger.info("Applying unl_wrapper to fix permissions")
console.print("Applying unl_wrapper to fix permissions")
os.system('/opt/unetlab/wrappers/unl_wrapper -a fixpermissions')
os.system(f'rm -f {file_downloaded}')
os.system("/opt/unetlab/wrappers/unl_wrapper -a fixpermissions")
os.system(f"rm -f {file_downloaded}")
if noztp:
self._disable_ztp(file_path=file_path)
@ -502,12 +557,12 @@ class ObjectDownloader():
version (str):
image_name (str, optional): Image name to use. Defaults to "arista/ceos".
"""
docker_image = f'{image_name}:{self.version}'
logger.info(f'Importing image {self.filename} to {docker_image}')
console.print(f'🚀 Importing image {self.filename} to {docker_image}')
os.system(f'$(which docker) import {self.filename} {docker_image}')
for filename in glob.glob(f'{self.filename}*'):
docker_image = f"{image_name}:{self.version}"
logger.info(f"Importing image {self.filename} to {docker_image}")
console.print(f"🚀 Importing image {self.filename} to {docker_image}")
os.system(f"$(which docker) import {self.filename} {docker_image}")
for filename in glob.glob(f"{self.filename}*"):
try:
os.remove(filename)
except FileNotFoundError:
console.print(f'File not found: {filename}')
console.print(f"File not found: {filename}")