Rewriting the same AnimeGAN demo over and over

Joe Bowser

OK... so, back in 2022, when I was still at Adobe, I was invited to speak at PyTorch Conference in New Orleans. Despite the travel freeze, I managed to get the company to cover some of my costs, and I made it to New Orleans. I eventually did this talk, where I pitted PyTorch Mobile, TFLite and ONNX Runtime against each other on a Pixel 6. I wanted to do the demo live like I did at the TensorFlow conference, but Meta wasn't able to accommodate that, so I took screenshots and fumbled through it. It's not my best talk, but it's my most recent, and my last talk as an Adobe employee.

Anyway, the thing I was proud of was that I managed to demo something at all, which was turning out to be harder and harder as AI got more competitive. Academics have NeurIPS and other venues, but old hackers have had nothing since O'Reilly stopped doing OSCON. Sure, I could probably do a talk at Open Source Summit, but since I've never been a cloud developer, it'd be really hard to figure out which stream I should be a part of. It also kept getting harder to justify doing talks at all. I always did them because it meant I could go to conventions for free, even though in hindsight it hasn't helped me professionally at all.

But besides that, I wanted to get back to this and do a follow-up, because despite the world moving on to generative models like LLMs, VLMs and diffusion-based models (a hive of scum and villainy), I was curious about all the different ways you could run a model on a device in 2024/2025.

So, old open source repos in hand, I embarked on an adventure. I had already run AnimeGAN with TFLite, ONNX Runtime and PyTorch Mobile. This time I wanted to see if I could run it these ways:

  • ONNX Runtime on iOS
  • CoreML
  • Rust using pyke-ort
  • Candle
  • ONNX Runtime using the Java SDK in Kotlin
  • ExecuTorch using the Java SDK in Kotlin

ONNX Runtime in C++

Honestly, this is a really basic model, and I'm a bit embarrassed to admit that I used OpenCV for the pixel shifting out of laziness. I should have just written the nested for loop to do this (there's a sketch of it below), or used one of the header-only libraries. This is boilerplate, but it's irritating to re-write in every language.

    cv::Mat cropped, resizedImage, preMat, out;
    // Resize to something acceptable for the NN to process
    cropped = centerCrop(input);

    cv::resize(cropped,
               resizedImage,
               cv::Size(512, 512),
               cv::InterpolationFlags::INTER_CUBIC);

    resizedImage.convertTo(resizedImage, CV_32F, 1.0 / 255);

    // Normalize to [-1,1] range
    resizedImage = resizedImage * 2 - 1;

    // Convert from HWC to CHW format
    std::vector<cv::Mat> channels;
    cv::split(resizedImage, channels);
    
    // Create output matrix with CHW layout (3x512x512)
    out = cv::Mat(3, 512 * 512, CV_32F);
    
    // Copy each channel's data into the output matrix
    for (int c = 0; c < 3; c++) {
        cv::Mat channel = channels[c].reshape(1, 1); // Flatten to 1D
        channel.copyTo(out.row(c));
    }
    
    // Reshape to 3D tensor format (3x512x512)
    out = out.reshape(1, 3);

    return out;

Honestly, this is the question I would ask people in interviews at Adobe all the time, and I would fail people if they didn't understand why you would switch from channels last (HWC) to channels first (CHW). This may seem absolutely trivial (and it is), but it's still an important step, and the same one applies when we start working with Visual Language Models, which need to normalize an image and break it up into patches before the vision encoding step.
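For the record, here's roughly what that nested loop looks like without OpenCV. This is a minimal sketch, assuming the resized, normalized 512x512 float pixels are already sitting in a flat src array in HWC order:

    // HWC index: (y * W + x) * C + c   (how image decoders hand you pixels)
    // CHW index: c * H * W + y * W + x (what most NN runtimes expect)
    void hwcToChw(const float* src, float* dst) {
        constexpr int W = 512, H = 512, C = 3;
        for (int c = 0; c < C; c++) {
            for (int y = 0; y < H; y++) {
                for (int x = 0; x < W; x++) {
                    dst[c * H * W + y * W + x] = src[(y * W + x) * C + c];
                }
            }
        }
    }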

Of course, we'd then go and just run it normally in ONNX Runtime:

    cv::Mat preProcessedMat = preProcess(input);
    int64_t width = preProcessedMat.cols;
    int64_t height = preProcessedMat.rows;

    LOGI("preProcessedMat cols: %d", int(width));
    LOGI("preProcessedMat rows: %d", int(height));

    // Allocate the things
    auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    size_t input_tensor_size = preProcessedMat.total() * preProcessedMat.channels();
    std::vector<int64_t> input_node_dims = {1, 3, 512, 512};

    LOGI("input_tensor_size: %d", int(input_tensor_size));

    if(preProcessedMat.data == nullptr)
    {
        LOGI("preProcessedMat is nullptr");
    }

    auto input_tensor = Ort::Value::CreateTensor<float>(memory_info, (float *)preProcessedMat.data,
                                                        input_tensor_size,
                                                        input_node_dims.data(), 4);

    // Check to make sure these are actually correct inputs and outputs
    const char* input_names[] = {"input.1"};
    const char* output_names[] = {"921"};

    LOGI("input_names: %s", input_names[0]);
    LOGI("output_names: %s", output_names[0]);

    auto output_result = session->Run(Ort::RunOptions{nullptr}, input_names,
                                &input_tensor, 1,
                                output_names, 1);

    LOGI("We were able to run the model");

    // Use std::move to transfer ownership instead of copying
    Ort::Value output_tensor = std::move(output_result[0]);
    float * output_data = output_tensor.GetTensorMutableData<float>();

    LOGI("We made it!");

    output = postProcess(output_data);

Honestly, maintaining the basic C++ hygiene of using std::move is critical here, especially since allocating more memory means doing a copy, and that's slower. I know we're dealing with CPU memory here, and not GPU memory from an execution provider, where you'd use some other method of keeping memory on the GPU (such as I/O binding), depending on which GPU backend you choose.
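To make that concrete, here's the same pair of lines with the wasteful alternative spelled out as a comment; the arithmetic assumes the 3x512x512 float output from above:

    // Move: adopt the allocation ORT already made for the output. No copy.
    Ort::Value output_tensor = std::move(output_result[0]);
    float* output_data = output_tensor.GetTensorMutableData<float>(); // a view, not a copy

    // The alternative costs a second ~3 MB allocation plus a memcpy:
    // std::vector<float> copied(output_data, output_data + 3 * 512 * 512);

(Ort::Value is move-only anyway, so accidentally copying the handle itself won't even compile; it's the data copy above that you have to consciously avoid.)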

So, this code was originally written for Android, but it also worked on iOS. On iOS, I had to write some Objective-C++ to pass things into Swift.

+ (UIImage *)processImage:(UIImage *)image {
    // Convert UIImage to cv::Mat
    cv::Mat inputMat;
    UIImageToMat(image, inputMat);
    
    if (inputMat.empty()) {
        NSLog(@"Failed to convert UIImage to cv::Mat");
        return nil;
    }
    
    // Create AnimeGan instance
    AnimeGan::AnimeGan animeGan;
    
    // Pre-process
    cv::Mat preprocessedMat = animeGan.preProcess(inputMat);
    
    // For testing, we'll just create dummy output data
    // In a real implementation, this would be the output from your ML model
    float *dummyOutputData = new float[512 * 512 * 3];
    for (int i = 0; i < 512 * 512 * 3; i++) {
        dummyOutputData[i] = 0.5f; // Just a placeholder value
    }
    
    // Post-process
    cv::Mat outputMat = animeGan.postProcess(dummyOutputData);
    
    // Clean up
    delete[] dummyOutputData;
    
    // Convert to UIImage
    UIImage *resultImage = MatToUIImage(outputMat);
    return resultImage;
}

Then in Swift:

    func processWithCoreML(_ image: UIImage) async throws -> UIImage? {
        // Check if we have the loaded model from init
        if let model = self.model {
            // Use the new direct method
            if let processedImage = model.processImage(image) {
                return processedImage
            } else {
                throw NSError(domain: "AnimeGanWrapper", code: 2, userInfo: [NSLocalizedDescriptionKey: "Model prediction failed"])
            }
        } else {
            // Fall back to the shared instance method if our direct model isn't loaded
            let model = AnimeGanModel.shared
            
            // Load the model if not already loaded
            if !model.isModelLoaded {
                try await model.loadModel()
            }
            
            // Get pixel buffer from preprocessed image and properly manage the returned object
            guard let unmanagedPixelBuffer = AnimeGanObjCWrapper.pixelBuffer(fromPreprocessedMat: image) else {
                throw NSError(domain: "AnimeGanWrapper", code: 1, userInfo: [NSLocalizedDescriptionKey: "Failed to create pixel buffer"])
            }
            
            // Convert Unmanaged<CVPixelBuffer> to CVPixelBuffer
            let pixelBuffer = unmanagedPixelBuffer.takeRetainedValue()
            
            // Run the model
            guard let outputData = try await model.processImage(pixelBuffer) else {
                throw NSError(domain: "AnimeGanWrapper", code: 2, userInfo: [NSLocalizedDescriptionKey: "Model returned nil output"])
            }
            
            // Convert [Float] to float*
            let count = outputData.count
            let outputPtr = UnsafeMutablePointer<Float>.allocate(capacity: count)
            outputData.withUnsafeBufferPointer { buffer in
                outputPtr.initialize(from: buffer.baseAddress!, count: count)
            }
            
            // Post-process
            let processedImage = AnimeGanObjCWrapper.postProcessImage(image, withOutputData: outputPtr)
            
            // Clean up
            outputPtr.deallocate()
            
            return processedImage
        }
    }

Of course, I added the slower ORT inference as a baseline reference as well:

    func processWithONNXRuntime(_ image: UIImage) -> UIImage? {
        // First check if the model exists in the documents directory
        let documentsDirectory = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)[0]
        let modelURL = documentsDirectory.appendingPathComponent("animegan_paprika_constraned.with_runtime_opt.ort")
        
        if FileManager.default.fileExists(atPath: modelURL.path) {
            print("Using downloaded ONNX model from documents directory")
            
            // Set up the ONNX Runtime with the downloaded model
            if !AnimeGanORTWrapper.setup(withModelPath: modelURL.path) {
                print("Failed to set up ONNX Runtime with downloaded model")
                return nil
            }
        } else {
            // Fall back to bundle if available (for development/testing)
            print("Downloaded model not found, checking bundle...")
            guard let bundleModelPath = Bundle.main.path(forResource: "animegan_paprika_constraned.with_runtime_opt", ofType: "ort") else {
                print("Failed to find ONNX model in bundle or documents directory")
                return nil
            }
            
            // Set up the ONNX Runtime with the bundled model
            if !AnimeGanORTWrapper.setup(withModelPath: bundleModelPath) {
                print("Failed to set up ONNX Runtime with bundled model")
                return nil
            }
        }
        
        // Process the image
        return AnimeGanORTWrapper.processImage(image)
    }

So, that was kinda fun. Then I dealt with the most frustrating task: doing this the Android-developer way, with Kotlin.

WTF? This isn't Kotlin, I thought you did Android Dev

OK, so, fun fact! You actually can't load model files by path with the ONNX Runtime that's out in the repository, so you have to do this to load the file in the first place:

        try {
            val modelPath = "/storage/emulated/0/Android/data/ai.baseweight.animegan/files/downloaded_model.onnx"
            val modelFile = File(modelPath)

            sessionOptions = OrtSession.SessionOptions()

            if (!modelFile.exists()) {
                Log.e("AnimeGan", "Model file does not exist at $modelPath")
                throw OrtException("Model file does not exist at $modelPath")
            }
            // For some reason ORT in Java can't read this file directly,
            // so read it into a byte array and hand that to ORT instead
            val modelBytes = modelFile.readBytes()
            session = env.createSession(modelBytes, sessionOptions)
        }
        catch (e: OrtException) {
            Log.d("AnimeGan", "OrtException: " + e.message)
        }

You'd have to go through bytes like this anyway if you were doing something like decrypting a model before loading it.
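For what it's worth, the C++ API has the same in-memory loading path. Here's a minimal sketch of the pattern, where decryptModel is a hypothetical stand-in for whatever scheme you'd actually use:

    #include <fstream>
    #include <iterator>
    #include <string>
    #include <vector>
    #include <onnxruntime_cxx_api.h>

    // Hypothetical placeholder: swap in your actual decryption routine.
    std::vector<char> decryptModel(const std::vector<char>& encrypted);

    Ort::Session loadDecryptedSession(Ort::Env& env, const std::string& path,
                                      const Ort::SessionOptions& options) {
        std::ifstream file(path, std::ios::binary);
        std::vector<char> encrypted((std::istreambuf_iterator<char>(file)),
                                    std::istreambuf_iterator<char>());
        std::vector<char> plain = decryptModel(encrypted);
        // ORT accepts the model as a raw byte span, same as the Kotlin path above
        return Ort::Session(env, plain.data(), plain.size(), options);
    }

Loading the model was merely annoying, though. It was not nearly as frustrating as pre-processing an image in Kotlin using the Android SDK. I straight up hated this, and this shit kinda made me hate life.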

    fun preprocess(inputBuffer: ByteBuffer, outputBuffer: FloatBuffer, width: Int, height: Int) {
        // Create a bitmap from the input buffer
        val inputBitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888)
        inputBuffer.rewind()
        inputBitmap.copyPixelsFromBuffer(inputBuffer)

        // Resize to 512x512 if needed
        val resizedBitmap = if (inputBitmap.width != 512 || inputBitmap.height != 512) {
            Bitmap.createScaledBitmap(inputBitmap, 512, 512, true)
        } else {
            inputBitmap
        }

        // Convert to RGB and normalize to [-1, 1]
        val pixels = IntArray(512 * 512)
        resizedBitmap.getPixels(pixels, 0, 512, 0, 0, 512, 512)
        
        outputBuffer.rewind()
        
        // Debug: Check first few pixel values
        Log.d("AnimeGan", "First few input pixels: ${pixels.take(5).joinToString()}")
        
        // First, write all R values
        for (pixel in pixels) {
            val r = ((pixel shr 16) and 0xFF) / 127.5f - 1.0f
            outputBuffer.put(r)
        }
        
        // Then, write all G values
        for (pixel in pixels) {
            val g = ((pixel shr 8) and 0xFF) / 127.5f - 1.0f
            outputBuffer.put(g)
        }
        
        // Finally, write all B values
        for (pixel in pixels) {
            val b = (pixel and 0xFF) / 127.5f - 1.0f
            outputBuffer.put(b)
        }
        
        outputBuffer.rewind()
        
        // Debug: Check first few normalized values
        val debugValues = FloatArray(15) // Get first 5 pixels (3 channels each)
        outputBuffer.get(debugValues)
        Log.d("AnimeGan", "First few normalized values: ${debugValues.joinToString()}")
        outputBuffer.rewind()
        
        // Debug: Check min/max values
        var minValue = Float.MAX_VALUE
        var maxValue = Float.NEGATIVE_INFINITY // Float.MIN_VALUE is the smallest positive float, not the most negative
        val allValues = FloatArray(outputBuffer.remaining())
        outputBuffer.get(allValues)
        for (value in allValues) {
            minValue = minOf(minValue, value)
            maxValue = maxOf(maxValue, value)
        }
        Log.d("AnimeGan", "Normalized value range: min=$minValue, max=$maxValue")
        outputBuffer.rewind()
    }

Of course, there's the main prediction, which is pretty boring and standard:

    fun doPredict(inputBuffer: ByteBuffer, outputBuffer: ByteBuffer, width: Int, height: Int) {
        try {
            // Create buffers for preprocessing
            val preprocessedBuffer = ByteBuffer.allocateDirect(512 * 512 * 3 * 4) // 3 channels, 4 bytes per float
            val floatBuffer = preprocessedBuffer.order(ByteOrder.nativeOrder()).asFloatBuffer()
            
            // Preprocess the input
            preprocess(inputBuffer, floatBuffer, width, height)

            val shape = longArrayOf(1L, 3L, 512L, 512L)

            // We can't rewind enough
            floatBuffer.rewind()

            val tensorFromBuffer = OnnxTensor.createTensor(env, floatBuffer, shape)
            val result = session.run(Collections.singletonMap("input", tensorFromBuffer), setOf("output"))
            result.use {
                // Assuming outputArray is a 4D float array from ONNX model
                val outputArray = result.get(0).value as Array<Array<Array<FloatArray>>>

                // get the values of the 4D array
                val batchSize = outputArray.size
                val channels = outputArray[0].size
                val imageHeight = outputArray[0][0].size
                val imageWidth = outputArray[0][0][0].size

                // Calculate the total number of elements
                val totalElements = batchSize * channels * imageHeight * imageWidth

                // Create a FloatBuffer to hold the flattened data
                val outBuffer = FloatBuffer.allocate(totalElements)

                // Iterate through the 4D array and put each element into the FloatBuffer
                for (b in 0 until batchSize) {
                    for (c in 0 until channels) {
                        for (h in 0 until imageHeight) {
                            for (w in 0 until imageWidth) {
                                outBuffer.put(outputArray[b][c][h][w])
                            }
                        }
                    }
                }

                // Rewind the FloatBuffer to the beginning
                outBuffer.rewind()
                postprocess(outBuffer, outputBuffer)
            }
        }
        catch (e: OrtException) {
            Log.d("AnimeGan", "OrtException: " + e.message)
        }
    }

Honestly, doing AI on Android in Kotlin makes you want to hate life. But hey, let's go straight into weird territory and throw Rust in the mix, because Rust is cool and we can use things like Candle. Except we can't, because while we can run Candle on Android, we can't load this ONNX file: it has a Reshape in it, and Candle doesn't like it when things reshape.

So... let's use pyke-ort, just to say we did use Rust.

I did this over here in this repo, and it was non-trivial, but I mean, look at this hawtness:

    fn preprocess(image: DynamicImage) -> AnyhowResult<Vec<f32>> {
        // Resize to 512x512
        let resized = image.resize_exact(512, 512, image::imageops::FilterType::CatmullRom);
        
        // Convert to RGB and normalize to [-1, 1]
        let rgb = resized.to_rgb8();
        let mut data = Vec::with_capacity(3 * 512 * 512);
        
        // Convert from HWC to CHW format and normalize
        for c in 0..3 {
            for y in 0..512 {
                for x in 0..512 {
                    let pixel = rgb.get_pixel(x, y);
                    // Normalize to [-1, 1]
                    // Note: RgbImage stores channels in RGB order, so we can use c directly
                    data.push((pixel[c] as f32 / 255.0) * 2.0 - 1.0);
                }
            }
        }
        
        Ok(data)
    }

The main purpose of this was to get used to bringing Rust components over to Android, like Hugging Face Tokenizers, as well as to keep my pre-processing logic in Rust, since it's nice and clean compared to doing this in raw C++. Is Rust good for everything? Well, given that I'm currently bashing my head on a SmolVLM demo using pyke-ort, and begging Claude-4-sonnet in Cursor to fill in the gaps in my Rust knowledge: no, no it is definitely not. But for pre-processing? Yeah, this is actually pretty great.

So, what did you learn?

Many weird and useless things:

  • There are many ways to run ONNX models in Rust
  • Kotlin sucks at image processing; Rust is awesome at it
  • AnimeGAN is a terrible demo to show a VC
  • This post is pretty basic, yet it's probably still too technical for most AI developers
  • ExecuTorch is kinda neat, but needs better C++ and Rust support
  • LiteRT hasn't changed very much

The majority of the old-school ML functionality that people ship is some variation of this basic structure (see the sketch after this list):

  • Pre-Processing
  • Inference
  • Post-Processing

In fact, that's what transformers and transformers.js do, and most people just use them without ever thinking about the internals. There are entire cottage industries of people who sell libraries that do just one thing, whether it's the quest for an even faster Whisper (which is what ArgMax and PicoVoice do) or full-on ML pipelines like what NimbleEdge does. There's also the entire empire that Ultralytics built on their GPLv3-licensed YOLO models, which everyone uses in their object detection workflows.

In Conclusion

It's good to have a nice, small test of known values when you're switching out tooling and frameworks, to make sure that you're still on the right track while doing this sort of work.
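Something as dumb as this does the job. A minimal sketch against the C++ preProcess from earlier, assuming only that a solid mid-gray frame should normalize to the same constant no matter which framework produced it:

    #include <cassert>
    #include <cmath>
    #include <opencv2/opencv.hpp>

    // Known-value smoke test: mid-gray in, a constant just below zero out.
    void testPreprocessKnownValues() {
        cv::Mat gray(512, 512, CV_8UC3, cv::Scalar(127, 127, 127));
        cv::Mat out = preProcess(gray); // the C++ preProcess from earlier
        float expected = 127.0f / 255.0f * 2.0f - 1.0f; // -0.0039...
        // Allow a little float slop from the resize interpolation
        assert(std::fabs(out.at<float>(0, 0) - expected) < 1e-4f);
    }

As for the tooling itself: I really hate using Kotlin or any JVM-based framework for image processing, and people who think it's the right choice for the job probably either don't have the option of using Rust or C++, or are scared of using things that are super strict or just plain unsafe.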

That being said, at least none of this is JavaScript or TypeScript.