OiO.lk Community platform!

I'm having compatibility errors with JAX and CUDA in the DALL·E mini Colab notebook

Thread starter: Mark Alonso

Mark Alonso

Guest
I am running the DALL·E mini inference pipeline in Google Colab, but when I run the image-generation cell below it fails with an error:

Code:
from flax.training.common_utils import shard_prng_key
import jax
import numpy as np
from IPython.display import display
from PIL import Image
from tqdm.notebook import trange

# prompts, n_predictions, key, tokenized_prompt, params, vqgan_params,
# gen_top_k, gen_top_p, temperature, cond_scale, p_generate and p_decode
# are defined in earlier cells of the notebook
print(f"Prompts: {prompts}\n")
# generate images
images = []
for i in trange(max(n_predictions // jax.device_count(), 1)):
    # get a new key
    key, subkey = jax.random.split(key)
    # generate images
    encoded_images = p_generate(
        tokenized_prompt,
        shard_prng_key(subkey),
        params,
        gen_top_k,
        gen_top_p,
        temperature,
        cond_scale,
    )
    # remove BOS
    encoded_images = encoded_images.sequences[..., 1:]
    # decode images
    decoded_images = p_decode(encoded_images, vqgan_params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    for decoded_img in decoded_images:
        img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
        images.append(img)
        display(img)
        print()

The error:

XlaRuntimeError: INTERNAL: CustomCall failed: jaxlib/gpu/prng_kernels.cc:33: operation gpuGetLastError() failed: cudaGetErrorString symbol not found
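The failing kernel named in the message (`jaxlib/gpu/prng_kernels.cc`) is part of JAX's GPU random-number code, which the loop most likely reaches first through `jax.random.split`. A minimal sketch for isolating that one call from the rest of the DALL·E mini pipeline, assuming only that `jax` is importable in the runtime; the helper name `check_prng_split` is hypothetical. If this small cell fails with the same `XlaRuntimeError`, the problem is in the JAX/CUDA installation rather than in the notebook code.

```python
def check_prng_split(seed: int = 0) -> str:
    """Run a single jax.random.split and report the outcome as a string."""
    try:
        import jax
    except ImportError:
        return "jax is not installed"
    try:
        key = jax.random.PRNGKey(seed)
        # Same call the generation loop makes before p_generate
        _, subkey = jax.random.split(key)
        # JAX dispatches asynchronously; wait so any runtime error surfaces here
        subkey.block_until_ready()
        return f"ok: backend={jax.default_backend()}, devices={jax.devices()}"
    except Exception as exc:  # report the error text instead of raising it
        return f"failed: {type(exc).__name__}: {exc}"


print(check_prng_split())
```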

Apparently this is caused by a version mismatch between CUDA and JAX. Does anyone know which versions should be used? Below are the dependencies I install; the commented-out lines are what I tried in order to fix it:

Code:
# Required only for colab environments + GPU

!pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Other install commands I tried instead (commented out):
#!pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
#!pip install --upgrade jax jaxlib


# Install required libraries
!pip install -q dalle-mini
!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git
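Because the later `pip install` lines can upgrade or downgrade packages that the first line pinned, it can help to print what actually ended up installed once the cell has run. A minimal sketch using only the standard library's `importlib.metadata`; the package list is an assumption taken from the names in the install cell and in the pip output below, and the helper name `installed_versions` is hypothetical. An already-imported module is not reloaded by `pip install`, so the Colab runtime may need a restart before `import jax` sees a reinstalled version.

```python
from importlib import metadata

# Distributions named in the install cell and in pip's conflict report
PACKAGES = ("jax", "jaxlib", "flax", "chex", "orbax-checkpoint", "dalle-mini")


def installed_versions(packages=PACKAGES) -> dict:
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions().items():
    print(f"{name:18} {version or 'not installed'}")
```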

I have tried installing the versions I was told were necessary, but that led to further incompatibilities:

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
chex 0.1.86 requires jax>=0.4.16, but you have jax 0.3.25 which is incompatible.
flax 0.8.4 requires jax>=0.4.19, but you have jax 0.3.25 which is incompatible.
orbax-checkpoint 0.4.4 requires jax>=0.4.9, but you have jax 0.3.25 which is incompatible.
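Read together, the three conflict lines each set a lower bound on `jax`, and the pinned 0.3.25 is below every one of them, so the smallest `jax` that satisfies all three is simply the largest bound. A minimal sketch that extracts those bounds from pip's message and takes the maximum; it assumes the exact `X requires jax>=V, but you have jax W` wording shown above and handles only `>=` constraints. The helper names are hypothetical, and the result says nothing about whether that `jax` version works with the Colab GPU's CUDA runtime or with dalle-mini itself.

```python
import re

PIP_REPORT = """\
chex 0.1.86 requires jax>=0.4.16, but you have jax 0.3.25 which is incompatible.
flax 0.8.4 requires jax>=0.4.19, but you have jax 0.3.25 which is incompatible.
orbax-checkpoint 0.4.4 requires jax>=0.4.9, but you have jax 0.3.25 which is incompatible.
"""

# One conflict line: "<pkg> <pkg version> requires jax>=<min>, but you have jax <ver>"
LINE_RE = re.compile(
    r"^(?P<pkg>\S+) \S+ requires jax>=(?P<min>[\d.]+), but you have jax [\d.]+",
    re.MULTILINE,
)


def as_tuple(version: str) -> tuple:
    """Turn a plain dotted version such as '0.4.19' into (0, 4, 19)."""
    return tuple(int(part) for part in version.split("."))


def binding_jax_bound(report: str) -> tuple:
    """Return (package, minimum jax) for the strictest '>=' bound in the report."""
    bounds = [(m["pkg"], m["min"]) for m in LINE_RE.finditer(report)]
    if not bounds:
        raise ValueError("no 'requires jax>=' lines found")
    # Compare numerically so that 0.4.9 sorts below 0.4.16
    return max(bounds, key=lambda item: as_tuple(item[1]))


pkg, minimum = binding_jax_bound(PIP_REPORT)
print(f"strictest bound: {pkg} requires jax >= {minimum}")
```

Here the strictest bound comes from flax 0.8.4 (jax>=0.4.19), which is why pinning `jax==0.3.25` conflicts with the chex, flax and orbax-checkpoint versions already present in the environment.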