Suggestion: Drop .pth models and move to safetensors #42

Open
alexlnkp opened this issue Jun 12, 2024 · 10 comments
Labels
enhancement (New feature or request), help wanted (Extra attention is needed)

Comments

@alexlnkp
Contributor

alexlnkp commented Jun 12, 2024

This idea has been around for quite some time. The main reasoning is that .pth files are fundamentally unsafe, allowing remote code execution if malicious code is injected. This brings us to the solution the original repo had: using SHA256 hashes for the main weights to ensure that at least the official weights haven't been tampered with.
However, this problem has a much simpler fix: using safetensors instead of raw .pth weights. That way no validity check is required at all, so startup time is sped up significantly. It also improves security, since users will not be able to inject any malicious code into the raw weights.
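
For illustration, a minimal sketch of what the switch could look like, assuming the checkpoint holds a flat tensor state dict (the file names and the model variable are placeholders, not the repo's actual paths or API):

from safetensors.torch import load_file, save_file
import torch

# One-time conversion of an existing .pth checkpoint (assuming it stores a
# plain str -> Tensor dict; nested objects would need to be flattened first).
state_dict = torch.load("G48k.pth", map_location="cpu")
save_file(state_dict, "G48k.safetensors")

# Loading later: no pickle is involved, so no arbitrary code can execute.
state_dict = load_file("G48k.safetensors", device="cpu")
# model.load_state_dict(state_dict)  # 'model' is a placeholder for the RVC net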

Sadly, I don't remember who the author of the original idea was; however, I believe it was notfelt from the AIHub Discord server.

@fumiama added the enhancement and help wanted labels on Jun 12, 2024
@fumiama
Owner

fumiama commented Jun 12, 2024

There are many external .pth models, like uvr5. We can drop support for those first, then consider the safetensors implementation.

P.S. The hash is not only for safety; the models are also large, and in case a model gets corrupted during download, a check at startup is necessary. Still, we could wrap it in a lazy check that verifies the hash only when the model is about to be loaded.
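
A rough sketch of such a lazy check, for illustration only (the function names here are made up, not the repo's actual API):

import hashlib
from pathlib import Path

import torch

def verify_sha256(path: Path, expected: str, chunk_size: int = 1 << 20) -> bool:
    """Stream the file through SHA256 and compare against the expected digest."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest() == expected

def lazy_load(path: Path, expected_sha256: str):
    """Verify the hash only when the model is actually about to be loaded."""
    if not verify_sha256(path, expected_sha256):
        raise RuntimeError(f"{path} is corrupted or has been tampered with")
    return torch.load(path, map_location="cpu")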

@alexlnkp
Contributor Author

There are many external .pth models, like uvr5. We can drop support for those first, then consider the safetensors implementation.

That sounds like a good plan!

P.S. The hash is not only for safety; the models are also large, and in case a model gets corrupted during download, a check at startup is necessary. Still, we could wrap it in a lazy check that verifies the hash only when the model is about to be loaded.

Huh, I see. A lazy check does sound like the best solution in that case.

@fumiama
Owner

fumiama commented Jun 12, 2024

A lazy check does sound like the best solution in that case.

Checking just one .pth takes less than 1 s, which may be acceptable.

@alexlnkp
Contributor Author

alexlnkp commented Jun 12, 2024

Checking just one .pth takes less than 1 s, which may be acceptable.

I mean, the check is simply "calculate the hash of the .pth and compare it to the stored hash"; this could be done in C so the check takes <0.1 s. I'm not sure this is a good solution right now, though, since RVC is not yet structured as a module, so the hash checker would have to be built externally and used as a Python module (at least until RVC is turned into a module).

I can start working on the hash checker in C and bind it to Python functions for easy access as a module.

@fumiama
Owner

fumiama commented Jun 12, 2024

Well, I don't know whether you know this or not, but Python's standard hashlib is not pure Python; it's written in C, if my memory is correct.
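
(For reference, a plain hashlib check of one file already uses the C-backed SHA256; on Python 3.11+ hashlib.file_digest avoids a manual read loop. The file name is just an example.)

import hashlib

with open("rmvpe.pt", "rb") as f:
    digest = hashlib.file_digest(f, "sha256").hexdigest()
print(digest)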

@alexlnkp
Contributor Author

alexlnkp commented Jun 12, 2024

Well, I don't know whether you know this or not, but Python's standard hashlib is not pure Python; it's written in C, if my memory is correct.

Huh, I didn't know that.
However, I meant more of a fully RVC-specific implementation that checks all of the files against their hashes, instead of using hashlib on each file independently.

Since the number of files is known beforehand and all of their hashes are accessible, this might improve performance, because we'd make a single call into C to hash-check all of the files.

So, if we don't hardcode the values in, it would be:

import os
from pathlib import Path

import RVC_Hash  # the proposed C-backed hash checker

# logger and check_model are assumed to be defined elsewhere in this module

def check_all_assets(update=False) -> bool:
    BASE_DIR = Path(__file__).resolve().parent.parent.parent

    logger.info("checking hubret & rmvpe...")

    if not check_model(
        BASE_DIR / "assets" / "hubert",
        "hubert_base.pt",
        os.environ["sha256_hubert_base_pt"],
        update,
    ):
        return False
    if not check_model(
        BASE_DIR / "assets" / "rmvpe",
        "rmvpe.pt",
        os.environ["sha256_rmvpe_pt"],
        update,
    ):
        return False
    if not check_model(
        BASE_DIR / "assets" / "rmvpe",
        "rmvpe.onnx",
        os.environ["sha256_rmvpe_onnx"],
        update,
    ):
        return False

    rvc_models_dir = BASE_DIR / "assets" / "pretrained"
    logger.info("checking pretrained models...")
    model_names = [
        "D32k.pth",
        "D40k.pth",
        "D48k.pth",
        "G32k.pth",
        "G40k.pth",
        "G48k.pth",
        "f0D32k.pth",
        "f0D40k.pth",
        "f0D48k.pth",
        "f0G32k.pth",
        "f0G40k.pth",
        "f0G48k.pth",
    ]
    return RVC_Hash.check_hashes(model_names)  # single call into C for all files

Or, in case we hardcode the values in:

import os
from pathlib import Path

import RVC_Hash  # the proposed C-backed hash checker

# logger and check_model are assumed to be defined elsewhere in this module

def check_all_assets(update=False) -> bool:
    BASE_DIR = Path(__file__).resolve().parent.parent.parent

    logger.info("checking hubret & rmvpe...")

    if not check_model(
        BASE_DIR / "assets" / "hubert",
        "hubert_base.pt",
        os.environ["sha256_hubert_base_pt"],
        update,
    ):
        return False
    if not check_model(
        BASE_DIR / "assets" / "rmvpe",
        "rmvpe.pt",
        os.environ["sha256_rmvpe_pt"],
        update,
    ):
        return False
    if not check_model(
        BASE_DIR / "assets" / "rmvpe",
        "rmvpe.onnx",
        os.environ["sha256_rmvpe_onnx"],
        update,
    ):
        return False

    rvc_models_dir = BASE_DIR / "assets" / "pretrained"
    logger.info("checking pretrained models...")
    return RVC_Hash.check_hashes()  # file list and hashes hardcoded on the C side

Instead of:

def check_all_assets(update=False) -> bool:
    BASE_DIR = Path(__file__).resolve().parent.parent.parent

    logger.info("checking hubret & rmvpe...")

    if not check_model(
        BASE_DIR / "assets" / "hubert",
        "hubert_base.pt",
        os.environ["sha256_hubert_base_pt"],
        update,
    ):
        return False
    if not check_model(
        BASE_DIR / "assets" / "rmvpe",
        "rmvpe.pt",
        os.environ["sha256_rmvpe_pt"],
        update,
    ):
        return False
    if not check_model(
        BASE_DIR / "assets" / "rmvpe",
        "rmvpe.onnx",
        os.environ["sha256_rmvpe_onnx"],
        update,
    ):
        return False

    rvc_models_dir = BASE_DIR / "assets" / "pretrained"
    logger.info("checking pretrained models...")
    model_names = [
        "D32k.pth",
        "D40k.pth",
        "D48k.pth",
        "G32k.pth",
        "G40k.pth",
        "G48k.pth",
        "f0D32k.pth",
        "f0D40k.pth",
        "f0D48k.pth",
        "f0G32k.pth",
        "f0G40k.pth",
        "f0G48k.pth",
    ]
    for model in model_names:
        menv = model.replace(".", "_")
        if not check_model(
            rvc_models_dir, model, os.environ[f"sha256_v1_{menv}"], update
        ):
            return False

This is because Python's for loops might have a small overhead compared to C loops.

@fumiama
Owner

fumiama commented Jun 12, 2024

Well, if you want to write a specialized program to do this stuff, I will not refuse it, but it should be a platform-independent program that can run under Windows, Linux, macOS, etc., and on architectures such as amd64, arm64, etc.

@alexlnkp
Contributor Author

Well, if you want to write a specialized program to do this stuff, I will not refuse it, but it should be a platform-independent program that can run under Windows, Linux, macOS, etc., and on architectures such as amd64, arm64, etc.

Noted! I will also attempt to make it work under both little-endian and big-endian!

@blaisewf
Contributor

Some interesting info here: https://huggingface.co/docs/hub/security-pickle
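
(For illustration, a minimal demo of why unpickling untrusted data is dangerous: pickle rebuilds objects by calling whatever __reduce__ tells it to, so merely loading a crafted file runs code. The payload below only echoes a string, but it could be anything.)

import os
import pickle

class Payload:
    def __reduce__(self):
        # pickle will "reconstruct" this object by calling os.system(...)
        return (os.system, ("echo arbitrary code executed",))

blob = pickle.dumps(Payload())
pickle.loads(blob)  # loading the blob prints "arbitrary code executed"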

@TheTrustedComputer

Alternatively, you can explicitly call torch.load with the argument weights_only=True. This will be the default in future PyTorch releases.
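
(For example, with a placeholder path; weights_only=True restricts the unpickler to tensors and other allow-listed types, so a malicious .pth cannot execute code.)

import torch

state_dict = torch.load("G48k.pth", map_location="cpu", weights_only=True)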
