Eccentric Developments


Use WebAssembly to Speedup DOD

Well well well, remember my last entry where I mentioned using wasm to improve the data-oriented implementation speed? It ended up being way more complicated than expected and probably not worth the effort; in any case I learned some interesting tidbits of information:

  1. It is a very bad idea to create closures right in the middle of your hot path.
  2. WebAssembly speed varies widely between browsers. Interestingly, Safari runs WebAssembly faster than Firefox and Chrome.
  3. Next time, remember not to count webassembly loading/instantiation into the final running time.
  4. Did I mention NO CLOSURES IN THE HOT PATH? this one really got me.
  5. Calling webassembly is costly, so either make it do as much as possible, or call it only when necessary. This is why I ended up using DataViews to access memory.
  6. And potentially the most important: never expect an easy win.

In the end, I was able to make the WebAssembly version almost as fast as the regular one using a bunch of micro-optimizations; see both versions below:

No webassembly:

/**
 * Builds the ray/triangle intersection routine.
 *
 * @param {object} args - Pipeline state; only `args.vector` is read
 *   (`add`, `scale`, `sub`, `permute`, `maxDimension`, `abs`).
 * @returns {{triangleIntersect: Function}} `triangleIntersect(ray, triangle)`
 *   returns `{ hit: false }` on a miss, or
 *   `{ hit: true, distance, point, normal }` on a hit, where `normal` is the
 *   precomputed `triangle.normal`.
 */
function createTriangleIntersectFunction(args) {
  const {
    vector: { add, scale, sub, permute, maxDimension, abs },
  } = args;
  const triangleIntersect = (ray, triangle) => {
    // Translate the vertices so the ray origin becomes the coordinate origin.
    let pt0_t = sub(triangle.pt0, ray.origin);
    let pt1_t = sub(triangle.pt1, ray.origin);
    let pt2_t = sub(triangle.pt2, ray.origin);

    // Project along the axis where the ray DIRECTION is largest. Using the
    // origin here (as the code previously did) can select an axis whose
    // direction component is zero, which makes `sz` infinite and turns every
    // edge function into NaN, so the triangle is silently missed.
    const k = maxDimension(abs(ray.direction));
    const i = (k + 1) % 3;
    const j = (k + 2) % 3;

    // Shear that aligns the permuted ray direction with +z.
    const pd = permute(ray.direction, i, j, k);
    const sz = 1.0 / pd[2];
    const sx = -pd[0] * sz;
    const sy = -pd[1] * sz;

    pt0_t = permute(pt0_t, i, j, k);
    pt1_t = permute(pt1_t, i, j, k);
    pt2_t = permute(pt2_t, i, j, k);

    const pt0_t_0 = pt0_t[0] + sx * pt0_t[2];
    const pt0_t_1 = pt0_t[1] + sy * pt0_t[2];
    const pt0_t_2 = pt0_t[2] * sz;
    const pt1_t_0 = pt1_t[0] + sx * pt1_t[2];
    const pt1_t_1 = pt1_t[1] + sy * pt1_t[2];
    const pt1_t_2 = pt1_t[2] * sz;
    const pt2_t_0 = pt2_t[0] + sx * pt2_t[2];
    const pt2_t_1 = pt2_t[1] + sy * pt2_t[2];
    const pt2_t_2 = pt2_t[2] * sz;

    // 2D edge functions of the sheared triangle, evaluated at the origin.
    const e0 = pt1_t_0 * pt2_t_1 - pt1_t_1 * pt2_t_0;
    const e1 = pt2_t_0 * pt0_t_1 - pt2_t_1 * pt0_t_0;
    const e2 = pt0_t_0 * pt1_t_1 - pt0_t_1 * pt1_t_0;

    // Mixed signs mean the origin lies outside the projected triangle.
    if (
      (e0 < 0.0 || e1 < 0.0 || e2 < 0.0) &&
      (e0 > 0.0 || e1 > 0.0 || e2 > 0.0)
    ) {
      return { hit: false };
    }

    // A zero determinant means the ray is edge-on to the triangle.
    const det = e0 + e1 + e2;
    if (det === 0.0) {
      return { hit: false };
    }

    const t_scaled = e0 * pt0_t_2 + e1 * pt1_t_2 + e2 * pt2_t_2;
    const t = t_scaled / det;

    // Same minimum-distance cut-off as the sphere intersection routine.
    if (t > 0.007) {
      const point = add(scale(ray.direction, t), ray.origin);
      return {
        hit: true,
        distance: t,
        point,
        normal: triangle.normal,
      };
    }

    return {
      hit: false,
    };
  };

  return {
    triangleIntersect,
  };
}

/**
 * Builds the hard-coded scene: five huge spheres acting as walls/floor/light
 * plus three small spheres.
 *
 * @param {object} args - Pipeline state; `args.vector` supplies `sub`, `unit`
 *   and `cross`, which are only used for objects flagged `isTriangle`.
 * @returns {{scene: object[]}} The list of scene objects.
 */
function createScene(args) {
  const {
    vector: { sub, unit, cross },
  } = args;

  // Every literal entry is a sphere, so describe them through one helper.
  const sphere = (center, radius, color, isLight) => ({
    center,
    radius,
    color,
    isLight,
    isSphere: true,
  });

  const objects = [
    sphere([0, -10002, 0], 9999.0, [1.0, 1.0, 1.0], false),
    sphere([-10012, 0, 0], 9999.0, [1, 0, 0], false),
    sphere([10012, 0, 0], 9999.0, [0, 1, 0], false),
    sphere([0, 0, 10012], 9999.0, [1, 1, 1], false),
    sphere([0, 10012, 0], 9999.0, [1, 1, 1], true),
    sphere([-5, 0, 2], 2.0, [1, 1, 0], false),
    sphere([0, 5, -1], 4.0, [1, 0, 0], false),
    sphere([8, 5, -1], 2, [0, 0, 1], false),
  ];

  // Triangles (none at the moment) get their face normal precomputed.
  for (const obj of objects) {
    if (!obj.isTriangle) continue;
    const edge0 = sub(obj.pt1, obj.pt0);
    const edge1 = sub(obj.pt2, obj.pt0);
    obj.normal = unit(cross(edge0, edge1));
  }

  return { scene: objects };
}

/**
 * Builds `randomDirection(normal)`, which returns a random unit vector lying
 * in the hemisphere around `normal`.
 *
 * @param {object} args - Pipeline state; `args.vector` supplies `unit`/`dot`.
 * @returns {{randomDirection: Function}}
 */
function createRandomDirectionFunction(args) {
  const { unit, dot } = args.vector;

  // One coordinate drawn from [-1, 1).
  const coordinate = () => 2 * Math.random() - 1;

  const randomDirection = (normal) => {
    let candidate;
    // Rejection sampling: retry until the candidate does not point away from
    // the normal.
    do {
      candidate = unit([coordinate(), coordinate(), coordinate()]);
    } while (!(dot(candidate, normal) >= 0));
    return candidate;
  };

  return { randomDirection };
}

/**
 * Builds the scalar ray/sphere intersection routine.
 *
 * @param {object} args - Pipeline state; `args.vector` supplies `sub`, `dot`,
 *   `scale`, `unit` and `add`.
 * @returns {{sphereIntersect: Function}} `sphereIntersect(ray, sphere)`
 *   returns `{ hit: false }` or `{ hit: true, distance, point, normal }`.
 */
function createSphereIntersectFunction(args) {
  const { sub, dot, scale, unit, add } = args.vector;

  // Hit record for the quadratic root `t`; the normal points from the sphere
  // centre towards the hit point.
  const hitAt = (ray, sphere, t) => {
    const point = add(scale(ray.direction, t), ray.origin);
    return {
      hit: true,
      distance: t,
      point,
      normal: unit(sub(point, sphere.center)),
    };
  };

  const sphereIntersect = (ray, sphere) => {
    // Solve a*t^2 + 2*b*t + c = 0 for the ray parameter t.
    const oc = sub(ray.origin, sphere.center);
    const a = dot(ray.direction, ray.direction);
    const b = dot(oc, ray.direction);
    const c = dot(oc, oc) - sphere.radius * sphere.radius;
    const dis = b * b - a * c;

    // No positive discriminant: the ray misses (or merely grazes) the sphere.
    if (!(dis > 0)) {
      return { hit: false };
    }

    const e = Math.sqrt(dis);

    // Prefer the nearer root; both must clear the 0.007 cut-off.
    const near = (-b - e) / a;
    if (near > 0.007) {
      return hitAt(ray, sphere, near);
    }

    const far = (-b + e) / a;
    if (far > 0.007) {
      return hitAt(ray, sphere, far);
    }

    return { hit: false };
  };

  return { sphereIntersect };
}

/**
 * Builds `aspectRatio(width, height)`, which reduces the two dimensions by
 * their greatest common divisor (Euclid's algorithm).
 *
 * @returns {{aspectRatio: Function}} `aspectRatio` returns
 *   `[width / gcd, height / gcd]`.
 */
function createAspectRatioFunction() {
  const aspectRatio = (width, height) => {
    let divisor = width;
    let remainder = height;
    // Euclid: (divisor, remainder) -> (remainder, divisor mod remainder).
    while (remainder != 0) {
      [divisor, remainder] = [remainder, divisor % remainder];
    }
    return [width / divisor, height / divisor];
  };

  return { aspectRatio };
}

/**
 * Derives the camera (three screen corners and the eye position) from the
 * reduced aspect ratio of the image.
 *
 * @param {object} args - Pipeline state; reads `width`, `height` and
 *   `aspectRatio`.
 * @returns {{camera: {leftTop: number[], rightTop: number[], leftBottom: number[], eye: number[]}}}
 */
function createCamera(args) {
  const { width, height, aspectRatio } = args;
  const ratio = aspectRatio(width, height);
  const halfWidth = ratio[0];
  const halfHeight = ratio[1];

  // The screen is centred on y = 1 and sits at z = -50, in front of the eye.
  const camera = {
    leftTop: [-halfWidth, halfHeight + 1, -50.0],
    rightTop: [halfWidth, halfHeight + 1, -50.0],
    leftBottom: [-halfWidth, -halfHeight + 1, -50.0],
    eye: [0.0, 0.0, -65.0],
  };

  return { camera };
}

/**
 * Groups the output image dimensions under a single `imageGeometry` key.
 *
 * @param {{width: number, height: number}} param0 - Image size in pixels.
 * @returns {{imageGeometry: {width: number, height: number}}}
 */
function createImageGeometry({ width, height }) {
  const imageGeometry = { width, height };
  return { imageGeometry };
}

/**
 * Builds the 3-component vector helpers used by every other stage. Vectors
 * are plain arrays `[x, y, z]`; all helpers return new arrays or numbers and
 * never mutate their arguments.
 *
 * @returns {{vector: object}} Object exposing `sub`, `add`, `mul`, `dot`,
 *   `scale`, `norm`, `unit`, `abs`, `maxDimension`, `permute` and `cross`.
 */
function createVectorFunction() {
  const sub = (A, B) => [A[0] - B[0], A[1] - B[1], A[2] - B[2]];
  const add = (A, B) => [A[0] + B[0], A[1] + B[1], A[2] + B[2]];
  // Component-wise product.
  const mul = (A, B) => [A[0] * B[0], A[1] * B[1], A[2] * B[2]];
  const dot = (A, B) => A[0] * B[0] + A[1] * B[1] + A[2] * B[2];
  const scale = (A, s) => [A[0] * s, A[1] * s, A[2] * s];
  const norm = (A) => Math.sqrt(dot(A, A));
  const unit = (A) => scale(A, 1.0 / norm(A));
  const abs = (A) => [Math.abs(A[0]), Math.abs(A[1]), Math.abs(A[2])];
  // Index of the strictly largest component; ties fall through to 2.
  const maxDimension = (A) => {
    if (A[0] > A[1] && A[0] > A[2]) return 0;
    // Compare against A[2]: the previous A[3] was always undefined, which
    // made this branch unreachable and mislabelled y-dominant vectors as 2.
    if (A[1] > A[0] && A[1] > A[2]) return 1;
    return 2;
  };
  const permute = (A, i, j, k) => [A[i], A[j], A[k]];
  const cross = (A, B) => {
    const j = A[1] * B[2] - B[1] * A[2];
    const k = A[2] * B[0] - A[0] * B[2];
    const l = A[0] * B[1] - A[1] * B[0];
    return [j, k, l];
  };
  const vector = {
    sub,
    add,
    mul,
    dot,
    scale,
    norm,
    unit,
    abs,
    maxDimension,
    permute,
    cross,
  };

  return {
    vector,
  };
}

/**
 * Generates one primary ray per pixel, in row-major order, from the eye
 * through the corresponding point on the camera screen.
 *
 * @param {object} args - Pipeline state; reads `camera`, `imageGeometry` and
 *   `vector` (`scale`, `add`, `sub`, `unit`).
 * @returns {{primaryRays: Array<{pixel: number, origin: number[], direction: number[]}>}}
 *   Every ray shares the same `origin` array (the camera eye).
 */
function calculatePrimaryRays(args) {
  const { camera, imageGeometry, vector } = args;
  const { rightTop, leftTop, leftBottom, eye } = camera;
  const { width, height } = imageGeometry;
  const { scale, add, sub, unit } = vector;

  // Screen-space step for one pixel to the right / one pixel down.
  const stepRight = scale(sub(rightTop, leftTop), 1.0 / width);
  const stepDown = scale(sub(leftBottom, leftTop), 1.0 / height);

  const primaryRays = [];
  for (let y = 0; y < height; y++) {
    for (let x = 0; x < width; x++) {
      const across = scale(stepRight, x);
      const down = scale(stepDown, y);
      const screenPoint = add(add(across, down), leftTop);
      const direction = unit(sub(screenPoint, eye));
      primaryRays.push({ pixel: y * width + x, origin: eye, direction });
    }
  }

  return { primaryRays };
}

/**
 * Builds `trace(ray)`, which tests the ray against every scene object and
 * keeps the closest hit.
 *
 * @param {object} args - Pipeline state; reads `scene`, `sphereIntersect`
 *   and `triangleIntersect`.
 * @returns {{trace: Function}} `trace` returns the winning intersection
 *   record with an added `obj` field, or
 *   `{ hit: false, distance: Number.MAX_VALUE }` when nothing is hit.
 */
function createTraceFunction(args) {
  const { scene, sphereIntersect, triangleIntersect } = args;
  const trace = (ray) => {
    let closestHit = { hit: false, distance: Number.MAX_VALUE };
    // `const` keeps the loop variable local. The previous bare `obj` was an
    // assignment to an undeclared name: an implicit global in sloppy mode
    // and a ReferenceError in strict mode or ES modules.
    for (const obj of scene) {
      const fn = obj.isTriangle ? triangleIntersect : sphereIntersect;
      const res = fn(ray, obj);
      if (!res.hit) continue;
      if (res.distance >= closestHit.distance) continue;
      closestHit = res;
      // Remember which object produced the hit so shading can read its colour.
      closestHit.obj = obj;
    }
    return closestHit;
  };
  return {
    trace,
  };
}

/**
 * Traces every primary ray and pairs it with its intersection record.
 *
 * @param {object} args - Pipeline state; reads `trace` and `primaryRays`.
 * @returns {{traceResults: Array<{ray: object, intersection: object}>}}
 *   Results in the same order as `primaryRays`.
 */
function tracePrimaryRays(args) {
  const { trace, primaryRays } = args;
  const traceResults = Array.from(primaryRays, (ray) => ({
    ray,
    intersection: trace(ray),
  }));
  return { traceResults };
}

/**
 * Turns the trace results into one RGB triple per pixel.
 *
 * Misses are black, lights show their own colour, and every other surface
 * has its colour multiplied by `shading(point, normal, 0)`.
 *
 * @param {object} args - Pipeline state; reads `traceResults`, `shading`
 *   and `vector.mul`.
 * @returns {{bitmap: number[][]}} Colours in the order of `traceResults`.
 */
function generateBitmap(args) {
  const {
    traceResults,
    shading,
    vector: { mul },
  } = args;

  const colourOf = (it) => {
    if (!it.hit) {
      return [0, 0, 0];
    }
    const base = it.obj.color;
    if (it.obj.isLight) {
      return base;
    }
    return mul(base, shading(it.point, it.normal, 0));
  };

  const bitmap = [];
  for (const { intersection } of traceResults) {
    bitmap.push(colourOf(intersection));
  }

  return { bitmap };
}

/**
 * Composes stage functions into a single runner.
 *
 * Each stage receives everything accumulated so far and returns an object
 * whose keys are merged on top; later stages therefore see (and may
 * override) the output of earlier ones. The caller's `args` is copied, not
 * mutated.
 *
 * @param {Function[]} fns - Stages, executed in order.
 * @returns {Function} `(args) => accumulated state after the last stage`.
 */
function pipeline(fns) {
  // Shallow merge; keys in `produced` win over keys in `state`.
  const merge = (state, produced) => ({ ...state, ...produced });

  return (args) => {
    let state = merge(args, undefined);
    for (const stage of fns) {
      const produced = stage(state);
      state = merge(state, produced);
    }
    return state;
  };
}

/**
 * Builds the recursive `shading(shadingPoint, pointNormal, depth)` routine.
 *
 * One random bounce is traced from the point. The result is black when the
 * bounce misses or when `depth` reaches 5; otherwise it is the colour of
 * whatever was hit, weighted by the cosine term `dot(normal, direction)`,
 * recursing when the hit object is not a light.
 *
 * @param {object} args - Pipeline state; reads `vector` (`add`, `scale`,
 *   `mul`, `dot`), `trace` and `randomDirection`.
 * @returns {{shading: Function}}
 */
function createShadingFunction(args) {
  const { vector, trace, randomDirection } = args;
  const { add, scale, mul, dot } = vector;

  const shading = (shadingPoint, pointNormal, depth) => {
    // Recursion limit.
    if (depth === 5) {
      return [0, 0, 0];
    }

    // Nudge the origin off the surface along the normal before bouncing.
    const origin = add(shadingPoint, scale(pointNormal, 0.01));
    const direction = randomDirection(pointNormal);
    const d = dot(pointNormal, direction);

    const tr = trace({ origin, direction });
    if (!tr.hit) {
      return [0, 0, 0];
    }

    const surface = tr.obj;
    if (surface.isLight) {
      return scale(surface.color, d);
    }

    const albedo = surface.color;
    const incoming = shading(tr.point, tr.normal, depth + 1);
    return mul(albedo, scale(incoming, d));
  };

  return { shading };
}

/**
 * Builds a deterministic pseudo-random generator (four 32-bit state words
 * updated with xor/shift steps, fixed seed).
 *
 * @returns {{random: Function}} `random()` returns the new last state word
 *   divided by 4294967295, i.e. a value in [0, 1].
 */
function createRandomFunction() {
  // Fixed seed: every pipeline run produces the same sequence.
  let s0 = 123456789;
  let s1 = 362436069;
  let s2 = 521288629;
  let s3 = 88675123;

  const random = () => {
    // `>>> 0` keeps the intermediate values as unsigned 32-bit integers.
    const mixed = (s0 ^ (s0 << 11)) >>> 0;
    s0 = s1;
    s1 = s2;
    s2 = s3;
    s3 = (s3 ^ (s3 >>> 19) ^ (mixed ^ (mixed >>> 8))) >>> 0;
    return s3 / 4294967295;
  };

  return { random };
}

// The full renderer, assembled from the stage functions defined in this
// script. `pipeline` merges each stage's returned object into the args seen
// by the following stages, so a stage must come after the stages whose
// output it destructures (e.g. everything reading `vector` follows
// createVectorFunction, and createShadingFunction follows
// createTraceFunction and createRandomDirectionFunction).
var renderingPipeline = pipeline([
  createVectorFunction, // -> vector
  createAspectRatioFunction, // -> aspectRatio
  createScene, // -> scene
  createCamera, // -> camera
  createImageGeometry, // -> imageGeometry
  createRandomFunction, // -> random (createRandomDirectionFunction calls Math.random instead)
  createRandomDirectionFunction, // -> randomDirection
  calculatePrimaryRays, // -> primaryRays
  createSphereIntersectFunction, // -> sphereIntersect
  createTriangleIntersectFunction, // -> triangleIntersect
  createTraceFunction, // -> trace
  createShadingFunction, // -> shading
  tracePrimaryRays, // -> traceResults
  generateBitmap, // -> bitmap
]);

// Entry point: render into the <canvas id="canvas-1"> element at its own
// resolution and report the elapsed time under the 'Render' console timer.
const canvas = document.getElementById('canvas-1');
const ctx = canvas.getContext('2d');
const width = canvas.width;
const height = canvas.height;

console.time('Render');
const result = renderingPipeline({ width, height });
const { bitmap } = result;
// Pack every [r, g, b] triple into one 32-bit word laid out as 0xAABBGGRR,
// with alpha forced to 255. The `<<` operator truncates each scaled channel
// to an integer.
// NOTE(review): channels look like they are expected in [0, 1]; a larger
// value would spill into the neighbouring channel's bits. Confirm upstream.
const finalBitmap = new Uint32Array(width * height);
for (let i = 0; i < bitmap.length; i++) {
  finalBitmap[i] =
    (255 << 24) |
    ((bitmap[i][2] * 255) << 16) |
    ((bitmap[i][1] * 255) << 8) |
    (bitmap[i][0] * 255);
}

// Reinterpret the packed words as bytes for ImageData. The resulting byte
// order is R, G, B, A only on a little-endian host.
const imageData = new ImageData(
  new Uint8ClampedArray(finalBitmap.buffer),
  width
);
ctx.putImageData(imageData, 0, 0);
console.timeEnd('Render')

With WebAssembly:

To get a more accurate measurement of the running time, check the console for the Render and WasmLoad messages.

/**
 * Fetches and instantiates a WebAssembly module with an empty import object.
 *
 * @param {string} wasmFile - URL of the `.wasm` file.
 * @returns {Promise<object>} The exports of the instantiated module.
 */
async function loadWasm(wasmFile) {
  const response = fetch(wasmFile);
  const instantiated = await WebAssembly.instantiateStreaming(response, {});
  return instantiated.instance.exports;
}

/**
 * Builds the ray/triangle intersection routine.
 *
 * @param {object} args - Pipeline state; only `args.vector` is read
 *   (`add`, `scale`, `sub`, `permute`, `maxDimension`, `abs`).
 * @returns {{triangleIntersect: Function}} `triangleIntersect(ray, triangle)`
 *   returns `{ hit: false }` on a miss, or
 *   `{ hit: true, distance, point, normal }` on a hit, where `normal` is the
 *   precomputed `triangle.normal`.
 */
function createTriangleIntersectFunction(args) {
  const {
    vector: { add, scale, sub, permute, maxDimension, abs },
  } = args;
  const triangleIntersect = (ray, triangle) => {
    // Translate the vertices so the ray origin becomes the coordinate origin.
    let pt0_t = sub(triangle.pt0, ray.origin);
    let pt1_t = sub(triangle.pt1, ray.origin);
    let pt2_t = sub(triangle.pt2, ray.origin);

    // Project along the axis where the ray DIRECTION is largest. Using the
    // origin here (as the code previously did) can select an axis whose
    // direction component is zero, which makes `sz` infinite and turns every
    // edge function into NaN, so the triangle is silently missed.
    const k = maxDimension(abs(ray.direction));
    const i = (k + 1) % 3;
    const j = (k + 2) % 3;

    // Shear that aligns the permuted ray direction with +z.
    const pd = permute(ray.direction, i, j, k);
    const sz = 1.0 / pd[2];
    const sx = -pd[0] * sz;
    const sy = -pd[1] * sz;

    pt0_t = permute(pt0_t, i, j, k);
    pt1_t = permute(pt1_t, i, j, k);
    pt2_t = permute(pt2_t, i, j, k);

    const pt0_t_0 = pt0_t[0] + sx * pt0_t[2];
    const pt0_t_1 = pt0_t[1] + sy * pt0_t[2];
    const pt0_t_2 = pt0_t[2] * sz;
    const pt1_t_0 = pt1_t[0] + sx * pt1_t[2];
    const pt1_t_1 = pt1_t[1] + sy * pt1_t[2];
    const pt1_t_2 = pt1_t[2] * sz;
    const pt2_t_0 = pt2_t[0] + sx * pt2_t[2];
    const pt2_t_1 = pt2_t[1] + sy * pt2_t[2];
    const pt2_t_2 = pt2_t[2] * sz;

    // 2D edge functions of the sheared triangle, evaluated at the origin.
    const e0 = pt1_t_0 * pt2_t_1 - pt1_t_1 * pt2_t_0;
    const e1 = pt2_t_0 * pt0_t_1 - pt2_t_1 * pt0_t_0;
    const e2 = pt0_t_0 * pt1_t_1 - pt0_t_1 * pt1_t_0;

    // Mixed signs mean the origin lies outside the projected triangle.
    if (
      (e0 < 0.0 || e1 < 0.0 || e2 < 0.0) &&
      (e0 > 0.0 || e1 > 0.0 || e2 > 0.0)
    ) {
      return { hit: false };
    }

    // A zero determinant means the ray is edge-on to the triangle.
    const det = e0 + e1 + e2;
    if (det === 0.0) {
      return { hit: false };
    }

    const t_scaled = e0 * pt0_t_2 + e1 * pt1_t_2 + e2 * pt2_t_2;
    const t = t_scaled / det;

    // Same minimum-distance cut-off as the sphere intersection routine.
    if (t > 0.007) {
      const point = add(scale(ray.direction, t), ray.origin);
      return {
        hit: true,
        distance: t,
        point,
        normal: triangle.normal,
      };
    }

    return {
      hit: false,
    };
  };

  return {
    triangleIntersect,
  };
}

/**
 * Builds the hard-coded scene and a structure-of-arrays copy of its spheres
 * stored in the static WebAssembly arena.
 *
 * @param {object} args - Pipeline state; reads `vector` (`sub`, `unit`,
 *   `cross`) and `memory.allocStaticFloat32Array`.
 * @returns {{scene: object[], spheresVector: object}} `scene` is the object
 *   list, each entry tagged with its index as `id`. `spheresVector` holds
 *   `centerX`, `centerY`, `centerZ` and `radius` arena arrays plus `id`, a
 *   plain array mapping each slot back to the sphere's scene index.
 */
function createScene(args) {
  const {
    vector: { sub, unit, cross },
    memory: { allocStaticFloat32Array }
  } = args;
  const scene = [
    {
      center: [0, -10002, 0],
      radius: 9999.0,
      color: [1.0, 1.0, 1.0],
      isLight: false,
      isSphere: true,
    },
    {
      center: [-10012, 0, 0],
      radius: 9999.0,
      color: [1, 0, 0],
      isLight: false,
      isSphere: true,
    },
    {
      center: [10012, 0, 0],
      radius: 9999.0,
      color: [0, 1, 0],
      isLight: false,
      isSphere: true,
    },
    {
      center: [0, 0, 10012],
      radius: 9999.0,
      color: [1, 1, 1],
      isLight: false,
      isSphere: true,
    },
    {
      center: [0, 10012, 0],
      radius: 9999.0,
      color: [1, 1, 1],
      isLight: true,
      isSphere: true,
    },
    {
      center: [-5, 0, 2],
      radius: 2.0,
      color: [1, 1, 0],
      isLight: false,
      isSphere: true,
    },
    {
      center: [0, 5, -1],
      radius: 4.0,
      color: [1, 0, 0],
      isLight: false,
      isSphere: true,
    },
    {
      center: [8, 5, -1],
      radius: 2,
      color: [0, 0, 1],
      isLight: false,
      isSphere: true,
    },
  ].map((obj, id) => {
    // Triangles (none at the moment) get their face normal precomputed.
    if (obj.isTriangle) {
      const edge0 = sub(obj.pt1, obj.pt0);
      const edge1 = sub(obj.pt2, obj.pt0);
      obj.normal = unit(cross(edge0, edge1));
    }
    // Scene index, used to map a sphere slot back to its scene object.
    obj.id = id;
    return obj;
  });

  // Structure-of-arrays copy of the spheres, one arena array per component.
  const spheres = scene.filter(obj => obj.isSphere);
  const count = spheres.length;
  const centerX = allocStaticFloat32Array(count);
  const centerY = allocStaticFloat32Array(count);
  const centerZ = allocStaticFloat32Array(count);
  const radius = allocStaticFloat32Array(count);
  const id = [];

  spheres.forEach((s, i) => {
    centerX.set(i, s.center[0]);
    centerY.set(i, s.center[1]);
    centerZ.set(i, s.center[2]);
    radius.set(i, s.radius);
    // Store the SCENE index, not the slot index. The two only coincide
    // while every scene object is a sphere.
    id[i] = s.id;
  });

  const spheresVector = {
    centerX,
    centerY,
    centerZ,
    radius,
    id,
  };
  return {
    scene,
    spheresVector
  };
}

/**
 * Builds `randomDirection(normal)`, which returns a random unit vector lying
 * in the hemisphere around `normal`.
 *
 * @param {object} args - Pipeline state; `args.vector` supplies `unit`/`dot`.
 * @returns {{randomDirection: Function}}
 */
function createRandomDirectionFunction(args) {
  const { unit, dot } = args.vector;

  // One coordinate drawn from [-1, 1).
  const coordinate = () => 2 * Math.random() - 1;

  const randomDirection = (normal) => {
    let candidate;
    // Rejection sampling: retry until the candidate does not point away from
    // the normal.
    do {
      candidate = unit([coordinate(), coordinate(), coordinate()]);
    } while (!(dot(candidate, normal) >= 0));
    return candidate;
  };

  return { randomDirection };
}

/**
 * Reserves two fixed-size arenas inside the WebAssembly linear memory and
 * exposes bump allocators for float32 arrays living in them.
 *
 * - The dynamic arena is scratch space; `free()` resets it in one step.
 * - The static arena is never reset.
 *
 * Allocated "arrays" are light descriptors, not typed arrays:
 * `{ length, byteOffset, ptr, get(i), set(i, v) }`, where `byteOffset` is
 * the absolute address handed to wasm functions and `ptr` is the offset
 * inside the arena used by `get`/`set`.
 *
 * @param {object} args - Pipeline state; reads `wasm` (`alloc`, `memory`).
 * @returns {{memory: {allocFloat32Array: Function, allocStaticFloat32Array: Function, free: Function}}}
 * @throws {RangeError} From either allocator when the request does not fit
 *   in what is left of its arena.
 */
function createMemoryFunctions(args) {
  const { wasm } = args;
  // Size of each arena in bytes (room for 1024 float32 values).
  const totalAvailable = 1024 * 4;
  const dynamicMemPtr = wasm.alloc(totalAvailable);
  const staticMemPtr = wasm.alloc(totalAvailable);
  // The views are created after both allocations and are bound to the
  // buffer object as it exists at this point.
  const dynView = new DataView(wasm.memory.buffer, dynamicMemPtr, totalAvailable);
  const staView = new DataView(wasm.memory.buffer, staticMemPtr, totalAvailable);

  // Refuse to hand out memory past the end of an arena. Without this check
  // the wasm routines, which receive raw addresses, would write over
  // whatever follows the arena.
  const ensureFits = (used, size) => {
    if ((used + size) * 4 > totalAvailable) {
      throw new RangeError(
        `Float32 arena exhausted: requested ${size}, ${totalAvailable / 4 - used} available`
      );
    }
  };

  // Shared accessors, defined once so allocating an array creates no closure.
  function dynGet(idx) { return dynView.getFloat32(this.ptr + idx * 4, true); }
  function dynSet(idx, v) { dynView.setFloat32(this.ptr + idx * 4, v, true); }
  let dynamicUsed = 0; // float32 slots handed out since the last free()
  const allocFloat32Array = (size) => {
    ensureFits(dynamicUsed, size);
    const ptr = dynamicUsed * 4;
    const byteOffset = dynamicMemPtr + ptr;
    dynamicUsed += size;
    return {
      length: size,
      byteOffset,
      ptr,
      get: dynGet,
      set: dynSet,
    };
  };

  function staGet(idx) { return staView.getFloat32(this.ptr + idx * 4, true); }
  function staSet(idx, v) { staView.setFloat32(this.ptr + idx * 4, v, true); }
  let staticUsed = 0; // float32 slots handed out so far
  const allocStaticFloat32Array = (size) => {
    ensureFits(staticUsed, size);
    const ptr = staticUsed * 4;
    const byteOffset = staticMemPtr + ptr;
    staticUsed += size;
    return {
      length: size,
      byteOffset,
      ptr,
      get: staGet,
      set: staSet,
    };
  };

  // Releases every dynamic array at once. Descriptors stay readable until
  // the next allocation reuses their slots.
  const free = () => (dynamicUsed = 0);
  return {
    memory: {
      allocFloat32Array,
      allocStaticFloat32Array,
      free,
    },
  };
}

/**
 * Builds the batched ray/sphere intersection routine, which evaluates the
 * quadratic for every sphere at once through the wasm array primitives.
 *
 * Scratch arrays are taken from the dynamic arena and several are reused
 * under new names once their previous contents are dead (see the aliases
 * below). The arena is reset before returning on every path, so the returned
 * array must be consumed before the next call.
 *
 * @param {object} args - Pipeline state; reads `simd`, `vector` and
 *   `memory` (`allocFloat32Array`, `free`).
 * @returns {{sphereIntersectSIMD: Function}}
 *   `sphereIntersectSIMD(ray, spheresVector)` returns an arena array with
 *   one entry per sphere: the smaller quadratic root where the discriminant
 *   is positive, 0 elsewhere.
 */
function createSphereIntersectSIMDFunction(args) {
  const {
    simd: { add, fill, sub, mulAdd, mul, sqrt, min, div },
    vector,
    memory: {
      allocFloat32Array,
      free,
    }
  } = args;
  const sphereIntersectSIMD = (ray, spheresVector) => {
    const len = spheresVector.id.length;
    // Broadcast the ray so it can be combined element-wise with the spheres.
    const OCX = allocFloat32Array(len); fill(ray.origin[0], OCX);
    const OCY = allocFloat32Array(len); fill(ray.origin[1], OCY);
    const OCZ = allocFloat32Array(len); fill(ray.origin[2], OCZ);
    const rayDirectionX = allocFloat32Array(len); fill(ray.direction[0], rayDirectionX);
    const rayDirectionY = allocFloat32Array(len); fill(ray.direction[1], rayDirectionY);
    const rayDirectionZ = allocFloat32Array(len); fill(ray.direction[2], rayDirectionZ);
    // OC = origin - center, per sphere.
    sub(OCX, spheresVector.centerX, OCX);
    sub(OCY, spheresVector.centerY, OCY);
    sub(OCZ, spheresVector.centerZ, OCZ);
    const A = allocFloat32Array(len); fill(vector.dot(ray.direction, ray.direction), A);
    // B = -(OC . direction); the sign flip is folded in here so the roots
    // below read (B -/+ E) / A.
    const B = allocFloat32Array(len); fill(0, B);
    mulAdd(OCX, rayDirectionX, B);
    mulAdd(OCY, rayDirectionY, B);
    mulAdd(OCZ, rayDirectionZ, B);
    const MO = rayDirectionY; fill(-1, MO); // reuse: direction Y is dead
    mul(B, MO, B);
    // C = OC . OC - radius^2
    const C = rayDirectionZ; fill(0, C); // reuse: direction Z is dead
    mulAdd(OCX, OCX, C);
    mulAdd(OCY, OCY, C);
    mulAdd(OCZ, OCZ, C);

    const MR = MO; mul(spheresVector.radius, spheresVector.radius, MR);
    sub(C, MR, C);
    // Discriminant DIS = B^2 - A*C, written over the OC arrays.
    const BB = OCX; mul(B, B, BB);
    const AC = OCY; mul(A, C, AC);
    const DIS = OCZ; sub(BB, AC, DIS);
    // MASK is 1 where the ray can hit the sphere, 0 elsewhere.
    const MASK = MR;
    let possibleHits = 0;
    for (let i = 0; i < len; i++) {
      const v = DIS.get(i) > 0 ? 1 : 0;
      MASK.set(i, v);
      possibleHits += v;
    }
    if (possibleHits === 0) {
      // Reset the arena on this path too. Skipping it leaked eight scratch
      // arrays per fully-missing ray until the arena overflowed.
      free();
      return MASK;
    }

    mul(DIS, MASK, DIS); // Prevent NaNs
    const E = DIS; sqrt(DIS, E);
    const T1 = BB; sub(B, E, T1);
    const T2 = AC; add(B, E, T2);
    div(T1, A, T1);
    div(T2, A, T2);

    // Keep the smaller root and zero the lanes that cannot hit.
    // NOTE(review): unlike the scalar sphereIntersect, this never falls back
    // to the larger root when the smaller one is behind the origin. Confirm
    // that is intended.
    min(T1, T2, T1);
    mul(T1, MASK, T1);
    free();
    return T1;
  };
  return {
    sphereIntersectSIMD,
  };
}

/**
 * Builds the scalar ray/sphere intersection routine.
 *
 * @param {object} args - Pipeline state; `args.vector` supplies `sub`, `dot`,
 *   `scale`, `unit` and `add`.
 * @returns {{sphereIntersect: Function}} `sphereIntersect(ray, sphere)`
 *   returns `{ hit: false }` or `{ hit: true, distance, point, normal }`.
 */
function createSphereIntersectFunction(args) {
  const { sub, dot, scale, unit, add } = args.vector;

  // Hit record for the quadratic root `t`; the normal points from the sphere
  // centre towards the hit point.
  const hitAt = (ray, sphere, t) => {
    const point = add(scale(ray.direction, t), ray.origin);
    return {
      hit: true,
      distance: t,
      point,
      normal: unit(sub(point, sphere.center)),
    };
  };

  const sphereIntersect = (ray, sphere) => {
    // Solve a*t^2 + 2*b*t + c = 0 for the ray parameter t.
    const oc = sub(ray.origin, sphere.center);
    const a = dot(ray.direction, ray.direction);
    const b = dot(oc, ray.direction);
    const c = dot(oc, oc) - sphere.radius * sphere.radius;
    const dis = b * b - a * c;

    // No positive discriminant: the ray misses (or merely grazes) the sphere.
    if (!(dis > 0)) {
      return { hit: false };
    }

    const e = Math.sqrt(dis);

    // Prefer the nearer root; both must clear the 0.007 cut-off.
    const near = (-b - e) / a;
    if (near > 0.007) {
      return hitAt(ray, sphere, near);
    }

    const far = (-b + e) / a;
    if (far > 0.007) {
      return hitAt(ray, sphere, far);
    }

    return { hit: false };
  };

  return { sphereIntersect };
}

/**
 * Builds `aspectRatio(width, height)`, which reduces the two dimensions by
 * their greatest common divisor (Euclid's algorithm).
 *
 * @returns {{aspectRatio: Function}} `aspectRatio` returns
 *   `[width / gcd, height / gcd]`.
 */
function createAspectRatioFunction() {
  const aspectRatio = (width, height) => {
    let divisor = width;
    let remainder = height;
    // Euclid: (divisor, remainder) -> (remainder, divisor mod remainder).
    while (remainder != 0) {
      [divisor, remainder] = [remainder, divisor % remainder];
    }
    return [width / divisor, height / divisor];
  };

  return { aspectRatio };
}

/**
 * Derives the camera (three screen corners and the eye position) from the
 * reduced aspect ratio of the image.
 *
 * @param {object} args - Pipeline state; reads `width`, `height` and
 *   `aspectRatio`.
 * @returns {{camera: {leftTop: number[], rightTop: number[], leftBottom: number[], eye: number[]}}}
 */
function createCamera(args) {
  const { width, height, aspectRatio } = args;
  const ratio = aspectRatio(width, height);
  const halfWidth = ratio[0];
  const halfHeight = ratio[1];

  // The screen is centred on y = 1 and sits at z = -50, in front of the eye.
  const camera = {
    leftTop: [-halfWidth, halfHeight + 1, -50.0],
    rightTop: [halfWidth, halfHeight + 1, -50.0],
    leftBottom: [-halfWidth, -halfHeight + 1, -50.0],
    eye: [0.0, 0.0, -65.0],
  };

  return { camera };
}

/**
 * Groups the output image dimensions under a single `imageGeometry` key.
 *
 * @param {{width: number, height: number}} param0 - Image size in pixels.
 * @returns {{imageGeometry: {width: number, height: number}}}
 */
function createImageGeometry({ width, height }) {
  const imageGeometry = { width, height };
  return { imageGeometry };
}

/**
 * Wraps the element-wise array routines exported by the wasm module so they
 * accept arena array descriptors instead of raw addresses.
 *
 * Every wrapper processes `OUT.length` elements and passes the descriptors'
 * `byteOffset` values to the export. Passing the same descriptor as an input
 * and as `OUT` requests an in-place operation.
 *
 * @param {object} args - Pipeline state; reads `wasm`.
 * @returns {{simd: object}} Wrappers `add`, `fill`, `sub`, `mulAdd`, `sqrt`,
 *   `div`, `mul`, `min` and `scale`.
 */
function createSIMDFunctions(args) {
  const { wasm } = args;

  const simd = {
    // OUT = A + B
    add: (A, B, OUT) => {
      wasm.add(OUT.length, A.byteOffset, B.byteOffset, OUT.byteOffset);
    },
    // OUT = b (scalar broadcast)
    fill: (b, OUT) => {
      wasm.fill(OUT.length, b, OUT.byteOffset);
    },
    // OUT = A - B
    sub: (A, B, OUT) => {
      wasm.sub(OUT.length, A.byteOffset, B.byteOffset, OUT.byteOffset);
    },
    // OUT += A * B
    mulAdd: (A, B, OUT) => {
      wasm.mul_add(OUT.length, A.byteOffset, B.byteOffset, OUT.byteOffset);
    },
    // OUT = sqrt(A)
    sqrt: (A, OUT) => {
      wasm.sqrt(OUT.length, A.byteOffset, OUT.byteOffset);
    },
    // OUT = A / B
    div: (A, B, OUT) => {
      wasm.div(OUT.length, A.byteOffset, B.byteOffset, OUT.byteOffset);
    },
    // OUT = A * B
    mul: (A, B, OUT) => {
      wasm.mul(OUT.length, A.byteOffset, B.byteOffset, OUT.byteOffset);
    },
    // OUT = min(A, B)
    min: (A, B, OUT) => {
      wasm.min(OUT.length, A.byteOffset, B.byteOffset, OUT.byteOffset);
    },
    // OUT = A * b (scalar)
    scale: (A, b, OUT) => {
      wasm.scale(OUT.length, b, A.byteOffset, OUT.byteOffset);
    },
  };

  return { simd };
}

/**
 * Builds the 3-component vector helpers used by every other stage. Vectors
 * are plain arrays `[x, y, z]`; all helpers return new arrays or numbers and
 * never mutate their arguments.
 *
 * @returns {{vector: object}} Object exposing `sub`, `add`, `mul`, `dot`,
 *   `scale`, `norm`, `unit`, `abs`, `maxDimension`, `permute` and `cross`.
 */
function createVectorFunctions() {
  const sub = (A, B) => [A[0] - B[0], A[1] - B[1], A[2] - B[2]];
  const add = (A, B) => [A[0] + B[0], A[1] + B[1], A[2] + B[2]];
  // Component-wise product.
  const mul = (A, B) => [A[0] * B[0], A[1] * B[1], A[2] * B[2]];
  const dot = (A, B) => A[0] * B[0] + A[1] * B[1] + A[2] * B[2];
  const scale = (A, s) => [A[0] * s, A[1] * s, A[2] * s];
  const norm = (A) => Math.sqrt(dot(A, A));
  const unit = (A) => scale(A, 1.0 / norm(A));
  const abs = (A) => [Math.abs(A[0]), Math.abs(A[1]), Math.abs(A[2])];
  // Index of the strictly largest component; ties fall through to 2.
  const maxDimension = (A) => {
    if (A[0] > A[1] && A[0] > A[2]) return 0;
    // Compare against A[2]: the previous A[3] was always undefined, which
    // made this branch unreachable and mislabelled y-dominant vectors as 2.
    if (A[1] > A[0] && A[1] > A[2]) return 1;
    return 2;
  };
  const permute = (A, i, j, k) => [A[i], A[j], A[k]];
  const cross = (A, B) => {
    const j = A[1] * B[2] - B[1] * A[2];
    const k = A[2] * B[0] - A[0] * B[2];
    const l = A[0] * B[1] - A[1] * B[0];
    return [j, k, l];
  };
  const vector = {
    sub,
    add,
    mul,
    dot,
    scale,
    norm,
    unit,
    abs,
    maxDimension,
    permute,
    cross,
  };

  return {
    vector,
  };
}

/**
 * Generates one primary ray per pixel, in row-major order, from the eye
 * through the corresponding point on the camera screen.
 *
 * @param {object} args - Pipeline state; reads `camera`, `imageGeometry` and
 *   `vector` (`scale`, `add`, `sub`, `unit`).
 * @returns {{primaryRays: Array<{pixel: number, origin: number[], direction: number[]}>}}
 *   Every ray shares the same `origin` array (the camera eye).
 */
function calculatePrimaryRays(args) {
  const { camera, imageGeometry, vector } = args;
  const { rightTop, leftTop, leftBottom, eye } = camera;
  const { width, height } = imageGeometry;
  const { scale, add, sub, unit } = vector;

  // Screen-space step for one pixel to the right / one pixel down.
  const stepRight = scale(sub(rightTop, leftTop), 1.0 / width);
  const stepDown = scale(sub(leftBottom, leftTop), 1.0 / height);

  const primaryRays = [];
  for (let y = 0; y < height; y++) {
    for (let x = 0; x < width; x++) {
      const across = scale(stepRight, x);
      const down = scale(stepDown, y);
      const screenPoint = add(add(across, down), leftTop);
      const direction = unit(sub(screenPoint, eye));
      primaryRays.push({ pixel: y * width + x, origin: eye, direction });
    }
  }

  return { primaryRays };
}

/**
 * Builds `trace(ray)` on top of the batched sphere intersection: one call
 * yields a distance per sphere, and the smallest positive one wins.
 *
 * @param {object} args - Pipeline state; reads `scene`, `spheresVector`,
 *   `sphereIntersectSIMD` and `vector` (`add`, `scale`, `unit`, `sub`).
 * @returns {{trace: Function}} `trace` returns
 *   `{ hit: true, distance, point, obj, normal }` or `{ hit: false }`.
 */
function createTraceSIMDFunction(args) {
  const { scene, spheresVector, sphereIntersectSIMD, vector: { add, scale, unit, sub } } = args;
  // The sphere count is fixed once the scene is built; read it once.
  const len = spheresVector.id.length;
  const trace = (ray) => {
    const distances = sphereIntersectSIMD(ray, spheresVector);
    let closestIndex = -1;
    let closestDistance = Number.MAX_VALUE;
    // Entries that are zero or negative are treated as "no hit".
    for (let i = 0; i < len; i++) {
      const distance = distances.get(i);
      if (distance > 0 && distance < closestDistance) {
        closestDistance = distance;
        closestIndex = i;
      }
    }
    if (closestIndex > -1) {
      // Map the slot in the sphere arrays back to its scene object through
      // `spheresVector.id`. Indexing `scene` with the slot directly is only
      // correct while every scene object is a sphere.
      const sphere = scene[spheresVector.id[closestIndex]];
      const point = add(scale(ray.direction, closestDistance), ray.origin);
      const normal = unit(sub(point, sphere.center));
      return {
        hit: true,
        distance: closestDistance,
        point,
        obj: sphere,
        normal,
      };
    }
    return {
      hit: false
    };
  };
  return {
    trace,
  };
}

/**
 * Builds `trace(ray)`, which tests the ray against every scene object and
 * keeps the closest hit.
 *
 * @param {object} args - Pipeline state; reads `scene`, `sphereIntersect`
 *   and `triangleIntersect`.
 * @returns {{trace: Function}} `trace` returns the winning intersection
 *   record with an added `obj` field, or
 *   `{ hit: false, distance: Number.MAX_VALUE }` when nothing is hit.
 */
function createTraceFunction(args) {
  const { scene, sphereIntersect, triangleIntersect } = args;
  const trace = (ray) => {
    let closestHit = { hit: false, distance: Number.MAX_VALUE };
    // `const` keeps the loop variable local. The previous bare `obj` was an
    // assignment to an undeclared name: an implicit global in sloppy mode
    // and a ReferenceError in strict mode or ES modules.
    for (const obj of scene) {
      const fn = obj.isTriangle ? triangleIntersect : sphereIntersect;
      const res = fn(ray, obj);
      if (!res.hit) continue;
      if (res.distance >= closestHit.distance) continue;
      closestHit = res;
      // Remember which object produced the hit so shading can read its colour.
      closestHit.obj = obj;
    }
    return closestHit;
  };
  return {
    trace,
  };
}

/**
 * Traces every primary ray and pairs it with its intersection record.
 *
 * @param {object} args - Pipeline state; reads `trace` and `primaryRays`.
 * @returns {{traceResults: Array<{ray: object, intersection: object}>}}
 *   Results in the same order as `primaryRays`.
 */
function tracePrimaryRays(args) {
  const { trace, primaryRays } = args;
  const traceResults = Array.from(primaryRays, (ray) => ({
    ray,
    intersection: trace(ray),
  }));
  return { traceResults };
}

/**
 * Turns the trace results into one RGB triple per pixel.
 *
 * Misses are black, lights show their own colour, and every other surface
 * has its colour multiplied by `shading(point, normal, 0)`.
 *
 * @param {object} args - Pipeline state; reads `traceResults`, `shading`
 *   and `vector.mul`.
 * @returns {{bitmap: number[][]}} Colours in the order of `traceResults`.
 */
function generateBitmap(args) {
  const {
    traceResults,
    shading,
    vector: { mul },
  } = args;

  const colourOf = (it) => {
    if (!it.hit) {
      return [0, 0, 0];
    }
    const base = it.obj.color;
    if (it.obj.isLight) {
      return base;
    }
    return mul(base, shading(it.point, it.normal, 0));
  };

  const bitmap = [];
  for (const { intersection } of traceResults) {
    bitmap.push(colourOf(intersection));
  }

  return { bitmap };
}

/**
 * Composes stage functions into a single runner.
 *
 * Each stage receives everything accumulated so far and returns an object
 * whose keys are merged on top; later stages therefore see (and may
 * override) the output of earlier ones. The caller's `args` is copied, not
 * mutated.
 *
 * @param {Function[]} fns - Stages, executed in order.
 * @returns {Function} `(args) => accumulated state after the last stage`.
 */
function pipeline(fns) {
  // Shallow merge; keys in `produced` win over keys in `state`.
  const merge = (state, produced) => ({ ...state, ...produced });

  return (args) => {
    let state = merge(args, undefined);
    for (const stage of fns) {
      const produced = stage(state);
      state = merge(state, produced);
    }
    return state;
  };
}

/**
 * Builds the recursive `shading(shadingPoint, pointNormal, depth)` routine.
 *
 * One random bounce is traced from the point. The result is black when the
 * bounce misses or when `depth` reaches 5; otherwise it is the colour of
 * whatever was hit, weighted by the cosine term `dot(normal, direction)`,
 * recursing when the hit object is not a light.
 *
 * @param {object} args - Pipeline state; reads `vector` (`add`, `scale`,
 *   `mul`, `dot`), `trace` and `randomDirection`.
 * @returns {{shading: Function}}
 */
function createShadingFunction(args) {
  const { vector, trace, randomDirection } = args;
  const { add, scale, mul, dot } = vector;

  const shading = (shadingPoint, pointNormal, depth) => {
    // Recursion limit.
    if (depth === 5) {
      return [0, 0, 0];
    }

    // Nudge the origin off the surface along the normal before bouncing.
    const origin = add(shadingPoint, scale(pointNormal, 0.01));
    const direction = randomDirection(pointNormal);
    const d = dot(pointNormal, direction);

    const tr = trace({ origin, direction });
    if (!tr.hit) {
      return [0, 0, 0];
    }

    const surface = tr.obj;
    if (surface.isLight) {
      return scale(surface.color, d);
    }

    const albedo = surface.color;
    const incoming = shading(tr.point, tr.normal, depth + 1);
    return mul(albedo, scale(incoming, d));
  };

  return { shading };
}

/**
 * Builds a deterministic pseudo-random generator (four 32-bit state words
 * updated with xor/shift steps, fixed seed).
 *
 * @returns {{random: Function}} `random()` returns the new last state word
 *   divided by 4294967295, i.e. a value in [0, 1].
 */
function createRandomFunction() {
  // Fixed seed: every pipeline run produces the same sequence.
  let s0 = 123456789;
  let s1 = 362436069;
  let s2 = 521288629;
  let s3 = 88675123;

  const random = () => {
    // `>>> 0` keeps the intermediate values as unsigned 32-bit integers.
    const mixed = (s0 ^ (s0 << 11)) >>> 0;
    s0 = s1;
    s1 = s2;
    s2 = s3;
    s3 = (s3 ^ (s3 >>> 19) ^ (mixed ^ (mixed >>> 8))) >>> 0;
    return s3 / 4294967295;
  };

  return { random };
}

// The WebAssembly-backed renderer. `pipeline` merges each stage's returned
// object into the args seen by the following stages, so a stage must come
// after the stages whose output it destructures. The caller provides `wasm`,
// `width` and `height` in the initial args.
var renderingPipeline = pipeline([
  createMemoryFunctions, // -> memory (reads wasm)
  createVectorFunctions, // -> vector
  createSIMDFunctions, // -> simd (reads wasm)
  createAspectRatioFunction, // -> aspectRatio
  createScene, // -> scene, spheresVector (reads memory)
  createCamera, // -> camera
  createImageGeometry, // -> imageGeometry
  createRandomFunction, // -> random (createRandomDirectionFunction calls Math.random instead)
  createRandomDirectionFunction, // -> randomDirection
  calculatePrimaryRays, // -> primaryRays
  createSphereIntersectFunction, // -> sphereIntersect (only read by the disabled createTraceFunction)
  createSphereIntersectSIMDFunction, // -> sphereIntersectSIMD
  createTriangleIntersectFunction, // -> triangleIntersect (only read by the disabled createTraceFunction)
  createTraceSIMDFunction, // -> trace (batched sphere version)
  // Scalar tracer, kept for comparison; enabling it after the line above
  // would replace `trace`.
  // createTraceFunction,
  createShadingFunction, // -> shading
  tracePrimaryRays, // -> traceResults
  generateBitmap, // -> bitmap
]);

// Entry point: load the wasm module, render into <canvas id="canvas-1"> and
// time both phases separately ('WasmLoad' and 'Render' console timers) so
// that module loading is not counted as rendering time.
(async () => {
  const canvas = document.getElementById('canvas-1');
  const ctx = canvas.getContext('2d');
  const width = canvas.width;
  const height = canvas.height;
  console.time("WasmLoad");
  const wasm = await loadWasm('wasm/vector.wasm');
  console.timeEnd("WasmLoad");
  console.time('Render');
  // The function returned by `pipeline` is synchronous; awaiting its plain
  // object result simply yields that object.
  const result = await renderingPipeline({ wasm, width, height });
  const { bitmap } = result;
  // Pack every [r, g, b] triple into one 32-bit word laid out as 0xAABBGGRR,
  // with alpha forced to 255. The `<<` operator truncates each scaled
  // channel to an integer.
  // NOTE(review): channels look like they are expected in [0, 1]; a larger
  // value would spill into the neighbouring channel's bits. Confirm upstream.
  const finalBitmap = new Uint32Array(width * height);
  for (let i = 0; i < bitmap.length; i++) {
    finalBitmap[i] =
      (255 << 24) |
      ((bitmap[i][2] * 255) << 16) |
      ((bitmap[i][1] * 255) << 8) |
      (bitmap[i][0] * 255);
  }

  // Reinterpret the packed words as bytes for ImageData. The resulting byte
  // order is R, G, B, A only on a little-endian host.
  const imageData = new ImageData(
    new Uint8ClampedArray(finalBitmap.buffer),
    width
  );
  ctx.putImageData(imageData, 0, 0);
  console.timeEnd('Render');
})();

Summary

After all the changes made to the implementation, I was expecting the data-oriented version of the code to be a bit faster, but the speedup never came. Thinking about it a bit more, that makes sense: data-oriented design helps with cache and RAM access, but this implementation has such a small number of objects in the scene that there is really no data access issue to fix.

One thing that I did not do in this implementation is use WebAssembly SIMD instructions. SIMD will be the last optimization before moving to algorithmic improvements.

Extras

If you are curious, this is the code behind the vector.wasm library:

use core::slice;
use std::ptr::copy;

/// Rebinds one, two or three raw pointers as `f32` slices of `$size`
/// elements for the duration of `$block`.
///
/// Inside the block each identifier shadows the pointer of the same name;
/// the last identifier is bound mutably, the others immutably, so the block
/// can use plain indexing (`c[i] = a[i] + b[i]`).
///
/// The buffers are only borrowed. Earlier this macro wrapped them in
/// `Vec<f32>` via `Vec::from_raw_parts` and then forgot the vectors, which
/// breaks that function's contract for memory that was not allocated as
/// exactly such a `Vec`. Slices need no matching allocation and nothing has
/// to be forgotten afterwards.
///
/// # Safety
///
/// Must be expanded in an unsafe context. Every pointer must be non-null,
/// aligned for `f32` and valid for `$size` elements, and the mutably bound
/// buffer must not overlap any of the others.
macro_rules! vec_and_forget {
    ($size:expr, $a: ident, $block: block) => {
        let $a: &mut [f32] = core::slice::from_raw_parts_mut($a as *mut f32, $size);
        $block
    };

    ($size:expr, $a: ident, $b: ident, $block: block) => {
        let $a: &[f32] = core::slice::from_raw_parts($a as *const f32, $size);
        let $b: &mut [f32] = core::slice::from_raw_parts_mut($b as *mut f32, $size);
        $block
    };

    ($size:expr, $a: ident, $b: ident, $c: ident, $block: block) => {
        let $a: &[f32] = core::slice::from_raw_parts($a as *const f32, $size);
        let $b: &[f32] = core::slice::from_raw_parts($b as *const f32, $size);
        let $c: &mut [f32] = core::slice::from_raw_parts_mut($c as *mut f32, $size);
        $block
    };
}

/// Reserves `capacity` bytes on the heap and returns a pointer to them.
///
/// The backing `Vec<u8>` is deliberately never dropped, so the memory stays
/// reserved until it is handed back through `dealloc`. The bytes are
/// uninitialised.
#[no_mangle]
pub unsafe fn alloc(capacity: usize) -> *mut u8 {
    // ManuallyDrop keeps the buffer alive after this function returns.
    let mut buffer = core::mem::ManuallyDrop::new(Vec::<u8>::with_capacity(capacity));
    buffer.as_mut_ptr()
}

/// Releases a buffer of `n` bytes starting at `ptr`.
///
/// The buffer is re-wrapped as a `Vec<u8>` with length and capacity `n` and
/// dropped immediately, which frees it.
#[no_mangle]
pub unsafe fn dealloc(n: usize, ptr: *mut u8) {
    drop(Vec::<u8>::from_raw_parts(ptr, n, n));
}

/// Reads the four bytes at `a` and decodes them as a little-endian `f32`.
#[no_mangle]
pub unsafe fn get(a: *const u8) -> f32 {
    let mut bytes = [0u8; 4];
    core::ptr::copy_nonoverlapping(a, bytes.as_mut_ptr(), bytes.len());
    f32::from_le_bytes(bytes)
}

/// Writes `value` to the four bytes at `a` in little-endian order.
///
/// # Safety
///
/// `a` must be valid for writing four bytes.
#[no_mangle]
pub unsafe fn set(a: *mut u8, value: f32) {
    // Keep the encoded bytes in a named local. Taking `.as_ptr()` directly
    // on the `to_le_bytes()` temporary left a dangling pointer, because the
    // temporary array is dropped at the end of that statement.
    let bytes = value.to_le_bytes();
    copy(bytes.as_ptr(), a, bytes.len());
}

/// Element-wise `c[i] = a[i] + b[i]` over `size` `f32` values.
///
/// Works on raw pointers so that `c` may be the same buffer as `a` or `b`
/// (the JavaScript side calls these routines in place). The earlier version
/// built separate `Vec<f32>` values over the same memory, which is not
/// allowed when one of them is mutable.
///
/// # Safety
///
/// All pointers must be aligned for `f32` and valid for `size` elements.
#[no_mangle]
pub unsafe fn add(size: usize, a: *const u8, b: *const u8, c: *mut u8) {
    let (lhs, rhs, out) = (a as *const f32, b as *const f32, c as *mut f32);
    for i in 0..size {
        *out.add(i) = *lhs.add(i) + *rhs.add(i);
    }
}

/// Sets `size` consecutive `f32` values starting at `a` to `v`.
///
/// Writes through the raw pointer instead of building a `Vec<f32>` over
/// memory that was not allocated as that vector.
///
/// # Safety
///
/// `a` must be aligned for `f32` and valid for writing `size` elements.
#[no_mangle]
pub unsafe fn fill(size: usize, v: f32, a: *mut u8) {
    let out = a as *mut f32;
    for i in 0..size {
        *out.add(i) = v;
    }
}

/// Element-wise `c[i] = a[i] - b[i]` over `size` `f32` values.
///
/// Works on raw pointers so that `c` may be the same buffer as `a` or `b`
/// (the JavaScript side calls this in place, e.g. `sub(OCX, centerX, OCX)`).
/// The earlier version built separate `Vec<f32>` values over the same
/// memory, which is not allowed when one of them is mutable.
///
/// # Safety
///
/// All pointers must be aligned for `f32` and valid for `size` elements.
#[no_mangle]
pub unsafe fn sub(size: usize, a: *const u8, b: *const u8, c: *mut u8) {
    let (lhs, rhs, out) = (a as *const f32, b as *const f32, c as *mut f32);
    for i in 0..size {
        *out.add(i) = *lhs.add(i) - *rhs.add(i);
    }
}

/// Element-wise accumulate `c[i] += a[i] * b[i]` over `size` `f32` values.
///
/// Works on raw pointers so that the buffers may coincide (the JavaScript
/// side passes the same array as both `a` and `b` to square it). The earlier
/// version built separate `Vec<f32>` values over the same memory.
///
/// # Safety
///
/// All pointers must be aligned for `f32` and valid for `size` elements.
#[no_mangle]
pub unsafe fn mul_add(size: usize, a: *const u8, b: *const u8, c: *mut u8) {
    let (lhs, rhs, acc) = (a as *const f32, b as *const f32, c as *mut f32);
    for i in 0..size {
        *acc.add(i) += *lhs.add(i) * *rhs.add(i);
    }
}

/// Element-wise `b[i] = sqrt(a[i])` over `size` `f32` values.
///
/// `b` is declared `*const u8` for interface compatibility, but it is the
/// output buffer and is written to. Works on raw pointers so that `b` may be
/// the same buffer as `a` (the JavaScript side takes the root in place); the
/// earlier version built separate `Vec<f32>` values over the same memory.
///
/// # Safety
///
/// Both pointers must be aligned for `f32` and valid for `size` elements,
/// and `b` must point to writable memory.
#[no_mangle]
pub unsafe fn sqrt(size: usize, a: *const u8, b: *const u8) {
    let (src, out) = (a as *const f32, b as *mut f32);
    for i in 0..size {
        *out.add(i) = (*src.add(i)).sqrt();
    }
}

/// Element-wise `b[i] = a[i] * v` over `size` `f32` values.
///
/// Works on raw pointers so that `b` may be the same buffer as `a`; the
/// earlier version built separate `Vec<f32>` values over the same memory.
///
/// # Safety
///
/// Both pointers must be aligned for `f32` and valid for `size` elements.
#[no_mangle]
pub unsafe fn scale(size: usize, v: f32, a: *mut u8, b: *mut u8) {
    let (src, out) = (a as *const f32, b as *mut f32);
    for i in 0..size {
        *out.add(i) = *src.add(i) * v;
    }
}

/// Element-wise `c[i] = a[i] / b[i]` over `size` `f32` values.
///
/// Works on raw pointers so that `c` may be the same buffer as `a` or `b`
/// (the JavaScript side calls this in place, e.g. `div(T1, A, T1)`). The
/// earlier version built separate `Vec<f32>` values over the same memory.
///
/// # Safety
///
/// All pointers must be aligned for `f32` and valid for `size` elements.
#[no_mangle]
pub unsafe fn div(size: usize, a: *const u8, b: *const u8, c: *mut u8) {
    let (lhs, rhs, out) = (a as *const f32, b as *const f32, c as *mut f32);
    for i in 0..size {
        *out.add(i) = *lhs.add(i) / *rhs.add(i);
    }
}

/// Element-wise `c[i] = a[i] * b[i]` over `size` `f32` values.
///
/// Works on raw pointers so that `c` may be the same buffer as `a` or `b`
/// (the JavaScript side calls this in place, e.g. `mul(B, MO, B)`). The
/// earlier version built separate `Vec<f32>` values over the same memory.
///
/// # Safety
///
/// All pointers must be aligned for `f32` and valid for `size` elements.
#[no_mangle]
pub unsafe fn mul(size: usize, a: *const u8, b: *const u8, c: *mut u8) {
    let (lhs, rhs, out) = (a as *const f32, b as *const f32, c as *mut f32);
    for i in 0..size {
        *out.add(i) = *lhs.add(i) * *rhs.add(i);
    }
}

/// Element-wise `c[i] = min(a[i], b[i])` over `size` `f32` values, using
/// `f32::min`.
///
/// Works on raw pointers so that `c` may be the same buffer as `a` or `b`
/// (the JavaScript side calls this in place, e.g. `min(T1, T2, T1)`). The
/// earlier version built separate `Vec<f32>` values over the same memory.
///
/// # Safety
///
/// All pointers must be aligned for `f32` and valid for `size` elements.
#[no_mangle]
pub unsafe fn min(size: usize, a: *const u8, b: *const u8, c: *mut u8) {
    let (lhs, rhs, out) = (a as *const f32, b as *const f32, c as *mut f32);
    for i in 0..size {
        *out.add(i) = (*lhs.add(i)).min(*rhs.add(i));
    }
}
Enrique CR - 2024-01-22