
Segmentation Fault When Using PyJulia Inside of PyTorch Custom Autograd Function #518

Open
@THargreaves

Description


I have a vector-to-scalar function and its corresponding derivative function written in Julia that I am unable to translate to Python. I would like to use these within PyTorch by defining a custom autograd function. As a simple, reproducible example, let's say the function is sum():

import numpy as np
import torch
from julia import Main

class JuliaSum(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)

        # Move the input to the CPU and convert it to a NumPy array
        # so that it can be passed to Julia.
        x = input.cpu().detach().numpy()

        return torch.FloatTensor([Main.sum(x)]).to('cuda')
        # return torch.FloatTensor([np.sum(x)]).to('cuda')

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        x = input.cpu().detach().numpy()

        # The gradient of sum() with respect to x is a vector of ones.
        y = torch.FloatTensor(Main.ones(len(x))).to('cuda')
        # y = torch.FloatTensor(np.ones(len(x))).to('cuda')

        return grad_output * y

input = torch.FloatTensor([0.1, 0.2, 0.3]).to('cuda').requires_grad_()

# Works — outputs `tensor([0.6000], device='cuda:0', grad_fn=<JuliaSumBackward>)`
y = JuliaSum.apply(input)
print(y)

# Works — outputs `tensor([1., 1., 1.], device='cuda:0')`
x = input.cpu().detach().numpy().astype(np.float64)
y_test = torch.FloatTensor(Main.ones(len(x))).to('cuda')
print(torch.ones(1).to('cuda') * y_test)

# Doesn't work — segmentation fault
y.backward(torch.ones(1).to('cuda'))
print(input.grad)

Calling the forward method works fine, as does running the body of the backward method directly from the global scope. However, when backward is invoked through autograd, I receive:

signal (11): Segmentation fault
in expression starting at none:0
Allocations: 3652709 (Pool: 3650429; Big: 2280); GC: 5
Segmentation fault (core dumped)         

The exact call causing the issue is Main.ones(len(x)). Replacing this with Main.ones(3), which does not touch x at all, still causes a segmentation fault, so the crash is not specific to the input array; it appears to be an issue with PyJulia accessing memory that has already been deallocated.

Also note that when I replace the two Julia calls with the corresponding NumPy calls (left commented out above), the backward method works fine. The code also works when all tensors are kept on the CPU, but my application requires GPU acceleration.
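
Since CPU-only graphs work, I suspect it may matter which thread the Julia call happens on. My understanding, which may be wrong, is that PyTorch's autograd engine runs CUDA backward passes on a worker thread, while CPU-only backward passes run on the calling thread, which would explain the CPU/GPU difference. Here is a minimal sketch to test that hypothesis independently of PyTorch, assuming the crash would reproduce from any non-main thread:

import threading

from julia import Main

def call_julia():
    # The same call that crashes inside backward(), now issued from a
    # plain Python worker thread instead of the main thread.
    print(Main.ones(3))

t = threading.Thread(target=call_julia)
t.start()
t.join()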

What is causing this segmentation fault, and how can I alter my code to avoid it whilst keeping the PyTorch tensors on the GPU?
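
For reference, this is the kind of restructuring I would settle for if it is the only way around the problem (a sketch only; JuliaSumEager is just a name I made up, and I have not verified that it avoids the crash): make every Julia call eagerly in forward() and stash the gradient on the context, so that backward() never enters Julia at all.

import torch
from julia import Main

class JuliaSumEager(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        x = input.cpu().detach().numpy()
        # Compute the gradient here, while still on the calling thread,
        # and stash it on the context so backward() never calls Julia.
        ctx.julia_grad = torch.FloatTensor(Main.ones(len(x))).to(input.device)
        return torch.FloatTensor([Main.sum(x)]).to(input.device)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.julia_grad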


I've included a Dockerfile that matches my environment to make reproducing this issue as simple as possible. For reference, I am using an RTX 3060.

FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04 

ARG PYTHON_VERSION=3.10.1
ARG JULIA_VERSION=1.7.1

ENV container docker
ENV DEBIAN_FRONTEND noninteractive
ENV LANG en_US.utf8
ENV MAKEFLAGS -j4

RUN mkdir /app
WORKDIR /app

# DEPENDENCIES
#===========================================
RUN apt-get update -y && \
    apt-get install -y gcc make wget libffi-dev \
        build-essential libssl-dev zlib1g-dev \
        libbz2-dev libreadline-dev libsqlite3-dev \
        libncurses5-dev libncursesw5-dev xz-utils \
        git

# INSTALL PYTHON
#===========================================
RUN wget https://www.python.org/ftp/python/$PYTHON_VERSION/Python-$PYTHON_VERSION.tgz && \
    tar -zxf Python-$PYTHON_VERSION.tgz && \
    cd Python-$PYTHON_VERSION && \
    ./configure --with-ensurepip=install --enable-shared && make && make install && \
    ldconfig && \
    ln -sf python3 /usr/local/bin/python
RUN python -m pip install --upgrade pip setuptools wheel && \
    python -m pip install julia numpy torch

# INSTALL JULIA
#====================================
RUN wget https://raw.githubusercontent.com/abelsiqueira/jill/main/jill.sh && \
    bash /app/jill.sh -y -v $JULIA_VERSION && \
    export PYTHON="python" && \
    julia -e 'using Pkg; ENV["PYTHON"] = "/usr/local/bin/python"' && \
    python -c 'import julia; julia.install()'

# CLEAN UP
#===========================================
RUN rm -rf /app/jill.sh \
    /opt/julias/*.tar.gz \
    /app/Python-$PYTHON_VERSION.tgz
RUN apt-get purge -y gcc make wget zlib1g-dev libffi-dev libssl-dev \
        libbz2-dev libreadline-dev \
        libncurses5-dev libncursesw5-dev xz-utils && \
    apt-get autoremove -y

CMD ["/bin/bash"]
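
To reproduce, build the image and run a script containing the code above with GPU access (repro.py is a placeholder name here; the script is assumed to sit in the current directory):

docker build -t pyjulia-segfault .
docker run --rm --gpus all -v "$PWD":/work -w /work pyjulia-segfault python repro.py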
