shithub: scrax

ref: 86a1bffad7063d1f51dfe44c2126bbdf8b7fe5c8
dir: /nonsense/matrix.js/

View raw version
import {MakeBlock, flatten, before, remove} from "./model.js"
import {NaturalSlot, compareName, Define} from "./core.js"
import {compile} from "./onnx.js"
import {toNumber} from "./compile.js"

export function extendStage()
{
	// The extension contributes its block palette plus a stage-level transform pass.
	let extension = {blocks, transform}
	return extension
}

/**
 * Replaces each maximal run of consecutive convertible blocks with a single
 * compiled "infer" block (built by convertReplacing).
 *
 * @param {Array} blocks - blocks to scan; on the outermost call these are the top-level blocks
 * @param {*} sprite - forwarded to convertReplacing
 * @param {Array} [all=blocks] - the full top-level block list. isAllowed uses it to find
 *   custom-block definitions, and convertReplacing uses it to swap replaced top-level blocks.
 *   Recursive calls pass it down unchanged; previously they passed a one-element array, so
 *   custom blocks inside nested stacks could not find their definition.
 */
async function transform(blocks, sprite, all = blocks)
{
	for (let block of blocks) {
		if (block.type.shape === "reporter") continue
		// NOTE(review): only the stack of the block being visited here is recursed into;
		// stacks belonging to later siblings walked by the loop below are not — confirm intended.
		if (block.stack?.length) await transform([block.stack[0]], sprite, all)
		// consecutive convertible blocks collected so far
		let run = []
		while (block) {
			if (flatten(block).every(part => isAllowed(all, part))) run.push(block)
			else await convertReplacing(all, run.splice(0), sprite)
			// advance to the next sibling; top-level blocks have no parent stack
			let index = block.parent?.stack?.indexOf(block)
			if (index === undefined) break
			block = block.parent.stack[index + 1]
		}
		await convertReplacing(all, run, sprite)
	}
}

/**
 * Decides whether `block` can be lowered to ONNX.
 *
 * Slots and block types listed in `conversions` are allowed. A custom reporter is allowed as-is.
 * A custom command block is allowed when its definition exists and every block at the top level
 * of the definition's stack is allowed.
 *
 * @param {Array} blocks - the full top-level block list, searched for the matching `define`
 * @param {*} block - block to check
 * @param {Set} [defines] - definitions currently being checked on this call path; a definition
 *   that reaches itself again is treated as recursion and rejected
 * @returns {boolean}
 */
function isAllowed(blocks, block, defines = new Set())
{
	if (block.type.shape === "slot") return true
	if (conversions.has(block.type)) return true
	if (block.type.category !== "custom") return false
	if (block.type.shape === "reporter") return true
	let define = blocks.find(block1 => block1.type === Define.type && compareName(block1.names[0], block.names[0]))
	// no definition in view: previously this fell through to `undefined.stack` and threw
	if (!define) return false
	if (defines.has(define)) return false
	defines.add(define)
	try {
		return define.stack.every(block1 => isAllowed(blocks, block1, defines))
	} finally {
		// Pop on the way out so `defines` tracks the current call path only;
		// otherwise calling the same custom block twice in one body looked like recursion.
		defines.delete(define)
	}
}

/**
 * Compiles a run of consecutive convertible blocks into one ONNX graph. The run is then
 * replaced in the tree by a single generated "infer" block that executes that graph.
 *
 * @param {Array} blocks0 - top-level block list; if the run's first block is in it, that slot
 *   is overwritten with the infer block
 * @param {Array} blocks - the run to replace (no-op when empty)
 * @param {*} sprite - argument handed to the compiled runner when the infer block runs
 */
async function convertReplacing(blocks0, blocks, sprite)
{
	if (!blocks.length) return
	
	let context = new Context()
	blocks.forEach(block => convert(blocks0, block, context))
	let runner = await compile(context)
	
	let Infer = MakeBlock({
		run: () => Object.assign(runner(sprite), {sync: true}),
		name: () => ["infer"],
		async: true,
	})
	let [first] = blocks
	let infer = Infer(first)
	if (first.parent) before(first, infer)
	for (let replaced of blocks) remove(replaced)
	let position = blocks0.indexOf(first)
	if (position !== -1) blocks0[position] = infer
}

/**
 * Accumulates one ONNX graph (nodes plus constant initialisers) for a run of converted blocks.
 * It also records which lists and variables the graph reads and writes.
 *
 * Value names are versioned: `lists/<name>/<n>` and `vars/<name>/<n>`, where n is how many times
 * that list/variable has been bumped in this context (0 = the incoming value). Anonymous
 * intermediates and constants are named `x/<n>`.
 */
class Context {
	
	// node records {opType, attribute, input, output} in emission order
	nodes = []
	// initialiser name -> array of constant values
	initialisers = new Map()
	// names read (input*) and written (output*) by the graph
	// NOTE(review): getList/getVariable record a name as an input even when it was already
	// bumped earlier in this context (i.e. produced internally) — confirm compile tolerates that.
	inputVariables = new Set()
	inputLists = new Set()
	outputVariables = new Set()
	outputLists = new Set()
	// per-prefix version counters backing the versioned names
	#variables = new Map()
	// joined constant values -> initialiser name, so equal constants are emitted only once
	#constants = new Map()
	
	// a brand-new anonymous value name (x/1, x/2, ...)
	#fresh()
	{
		return this.#bumpName("x")
	}
	
	/** Current name of list `name`; records it as a graph input unless `update` is false. */
	getList(name, update = true)
	{
		if (update) this.inputLists.add(name)
		return this.#getName(`lists/${name}`)
	}
	
	/** Current name of variable `name`; records it as a graph input unless `update` is false. */
	getVariable(name, update = true)
	{
		if (update) this.inputVariables.add(name)
		return this.#getName(`vars/${name}`)
	}
	
	/** Allocates the next version of list `name` and records it as a graph output. */
	bumpList(name)
	{
		this.outputLists.add(name)
		return this.#bumpName(`lists/${name}`)
	}
	
	/** Allocates the next version of variable `name` and records it as a graph output. */
	bumpVariable(name)
	{
		this.outputVariables.add(name)
		return this.#bumpName(`vars/${name}`)
	}
	
	/**
	 * Appends an ONNX node.
	 *
	 * @param {string} opType - ONNX operator name
	 * @param {Object} attribute - attribute name -> integer value
	 * @param {string[]} input - input value names
	 * @param {string[]} [output] - output names; defaults to one fresh anonymous name
	 * @returns {string} the node's first output name
	 */
	make(opType, attribute, input, output)
	{
		if (!output) output = [this.#fresh()]
		// every attribute is emitted as an ONNX INT attribute (type 2, value stored in `i`)
		let attributes = Object.entries(attribute).map(([name, i]) => ({name, i, type: 2}))
		this.nodes.push({opType, attribute: attributes, input, output})
		return output[0]
	}
	
	/**
	 * Returns the name of an initialiser holding `values`, creating it on first use.
	 * Previously the cache was read but never written, so every call emitted a duplicate.
	 */
	constant(...values)
	{
		let key = values.join(",")
		let cached = this.#constants.get(key)
		if (cached !== undefined) return cached
		let name = this.#fresh()
		this.initialisers.set(name, values)
		this.#constants.set(key, name)
		return name
	}
	
	// increments the version of `name` and returns the new versioned name
	#bumpName(name)
	{
		let n = (this.#variables.get(name) ?? 0) + 1
		this.#variables.set(name, n)
		return name + "/" + n
	}
	
	// returns the current versioned name of `name` without bumping it
	#getName(name)
	{
		let n = this.#variables.get(name) ?? 0
		return name + "/" + n
	}
}

/**
 * Lowers one block into ONNX nodes on `ctx`.
 *
 * @param {Array} blocks - full top-level block list (passed to flattenCustom)
 * @param {*} block - block to lower
 * @param {Context} ctx - graph being built
 * @returns {string|undefined} for a slot, the name of the constant holding its numeric value;
 *   otherwise undefined
 * @throws {Error} when the block has no entry in `conversions`
 */
function convert(blocks, block, ctx)
{
	if (block.type.shape === "slot") {
		return ctx.constant(toNumber(block.value))
	}
	if (block.type.category === "custom") {
		for (let block1 of flattenCustom(blocks, block)) {
			// flattenCustom currently hands back the call itself; converting that again
			// recursed without end (stack overflow), so skip it
			if (block1 !== block) convert(blocks, block1, ctx)
		}
	}
	
	let conversion = conversions.get(block.type)
	// fail clearly instead of "conversions.get(...) is not a function"
	if (!conversion) throw new Error(`matrix: no ONNX conversion for block of category "${block.type.category}"`)
	let inputs = block.inputs.map(block1 => convert(blocks, block1, ctx))
	conversion(ctx, ...inputs, block)
}

/**
 * Stub: always yields a fresh one-element array holding just `block`. The unused `_blocks`
 * parameter suggests it was meant to expand a custom block into its definition's body
 * (TODO confirm).
 */
function flattenCustom(_blocks, block)
{
	return Array.of(block)
}

// The four element-wise blocks differ only in their operator symbol: each takes three list
// references (out, a, b) and is labelled "<icon> a <symbol> b into out".
function elementwise(symbol)
{
	return MakeBlock({references: ["list", "list", "list"], name: (out, a, b) => [getIcon(), a, symbol, b, "into", out]})
}

let Add = elementwise("+")
let Subtract = elementwise("-")
let Multiply = elementwise("*")
let Divide = elementwise("/")
// Matrix product additionally takes a natural-number slot `j` (the shared inner dimension).
let MatrixMultiply = MakeBlock({
	references: ["list", "list", "list"],
	slots: [NaturalSlot],
	name: (out, a, b, j) => [getIcon(), a, "\xD7", b, "with", j, "into", out],
})

// palette contributed by this extension (returned from extendStage)
let blocks = [Add, Subtract, Multiply, Divide, MatrixMultiply]

// Maps each matrix block type to a lowering function that appends ONNX nodes to a Context.
// convert calls each entry as (ctx, ...slotInputs, block). The slot inputs are the values convert
// returned for block.inputs; of the blocks defined above, only MatrixMultiply declares a slot (`j`).
//
// NOTE(review): every entry reads block.names[0] and block.names[1] as the operand lists and
// bumps block.names[2] as the result. However, the label functions above take their references as
// (out, a, b) and print the first one after "into". If MakeBlock fills `names` in that same
// order, operands and output are swapped here — confirm against model.js.
let conversions = new Map([
	// element-wise ops; binary() pads both operands to a common length first
	[Add.type, (ctx, block) => binary(ctx, "Add", ctx.getList(block.names[0]), ctx.getList(block.names[1]),ctx.bumpList(block.names[2]))],
	[Subtract.type, (ctx, block) => binary(ctx, "Sub", ctx.getList(block.names[0]), ctx.getList(block.names[1]),ctx.bumpList(block.names[2]))],
	[Multiply.type, (ctx, block) => binary(ctx, "Mul", ctx.getList(block.names[0]), ctx.getList(block.names[1]),ctx.bumpList(block.names[2]))],
	[Divide.type, (ctx, block) => binary(ctx, "Div", ctx.getList(block.names[0]), ctx.getList(block.names[1]),ctx.bumpList(block.names[2]))],
	[MatrixMultiply.type, (ctx, j, block) =>
	{
		// Reads both lists as row-major matrices sharing inner dimension j:
		// names[0] -> ceil(len/j) x j and names[1] -> j x ceil(len/j), each zero-padded to fit.
		// Cast target 7 is ONNX INT64, matching the int64 tensors produced by Shape.
		let zero = ctx.make("Cast", {to: 7}, [ctx.constant(0)])
		let one = ctx.make("Cast", {to: 7}, [ctx.constant(1)])
		
		// j1 = max(j, 1) guards the divisions below; j2 = j1 - 1 turns the integer Div into a ceiling
		let j0 = ctx.make("Cast", {to: 7}, [j])
		let j1 = ctx.make("Max", {}, [j0, one])
		let j2 = ctx.make("Sub", {}, [j1, one])
		
		let shape1 = ctx.make("Shape", {}, [ctx.getList(block.names[0])])
		let shape2 = ctx.make("Shape", {}, [ctx.getList(block.names[1])])
		
		// b = (len + j1 - 1) / j1 = ceil(len / j1), given integer Div and non-negative lengths
		let a1 = ctx.make("Add", {}, [shape1, j2])
		let a2 = ctx.make("Add", {}, [shape2, j2])
		
		let b1 = ctx.make("Div", {}, [a1, j1])
		let b2 = ctx.make("Div", {}, [a2, j1])
		
		// d = b * j1 - len: how many elements are missing to fill the matrix
		let c1 = ctx.make("Mul", {}, [b1, j1])
		let c2 = ctx.make("Mul", {}, [b2, j1])
		
		let d1 = ctx.make("Sub", {}, [c1, shape1])
		let d2 = ctx.make("Sub", {}, [c2, shape2])
		
		// Pad's pads input is [begin, end] here: nothing before, d elements after
		let padding1 = ctx.make("Concat", {axis: 0}, [zero, d1])
		let padding2 = ctx.make("Concat", {axis: 0}, [zero, d2])
		
		let vector1 = ctx.make("Pad", {}, [ctx.getList(block.names[0]), padding1])
		let vector2 = ctx.make("Pad", {}, [ctx.getList(block.names[1]), padding2])
		
		// target shapes [b1, j1] and [j1, b2]
		let mShape1 = ctx.make("Concat", {axis: 0}, [b1, j1])
		let mShape2 = ctx.make("Concat", {axis: 0}, [j1, b2])
		
		let matrix1 = ctx.make("Reshape", {allowzero: 1}, [vector1, mShape1])
		let matrix2 = ctx.make("Reshape", {allowzero: 1}, [vector2, mShape2])
		
		// (b1 x j1) @ (j1 x b2) -> b1 x b2, then flattened back into a list of b1*b2 elements
		let result = ctx.make("MatMul", {}, [matrix1, matrix2])
		let shape = ctx.make("Mul", {}, [b1, b2])
		
		ctx.make("Reshape", {allowzero: 1}, [result, shape], [ctx.bumpList(block.names[2])])
	}],
])

/**
 * Emits an element-wise ONNX `op` on lists `a` and `b`, writing the result to `out`.
 * The shorter operand is first padded at the end (Pad's default constant fill) to the
 * longer one's length. This presumes both lists are 1-D tensors.
 *
 * @param {Context} ctx - graph being built
 * @param {string} op - ONNX operator name (Add, Sub, Mul, Div, ...)
 * @param {string} a - left operand value name
 * @param {string} b - right operand value name
 * @param {string} out - name for the result
 */
function binary(ctx, op, a, b, out)
{
	let zero = ctx.make("Cast", {to: 7}, [ctx.constant(0)])
	let shapeA = ctx.make("Shape", {}, [a])
	let shapeB = ctx.make("Shape", {}, [b])
	let target = ctx.make("Max", {}, [shapeA, shapeB])
	
	// pads = [0, target - ownShape]: nothing before, the shortfall after
	let padToTarget = (input, ownShape) => {
		let missing = ctx.make("Sub", {}, [target, ownShape])
		let pads = ctx.make("Concat", {axis: 0}, [zero, missing])
		return ctx.make("Pad", {}, [input, pads])
	}
	
	let paddedA = padToTarget(a, shapeA)
	let paddedB = padToTarget(b, shapeB)
	ctx.make(op, {}, [paddedA, paddedB], [out])
}

// Inline SVG markup used as the label icon of every matrix block; getIcon inserts it into a fresh <span>.
let icon = `<svg viewbox="0 0 24 24" fill="#48E" stroke="#444"><path d="M19.46,8l0.79-1.75L22,5.46c0.39-0.18,0.39-0.73,0-0.91l-1.75-0.79L19.46,2c-0.18-0.39-0.73-0.39-0.91,0l-0.79,1.75 L16,4.54c-0.39,0.18-0.39,0.73,0,0.91l1.75,0.79L18.54,8C18.72,8.39,19.28,8.39,19.46,8z M11.5,9.5L9.91,6 C9.56,5.22,8.44,5.22,8.09,6L6.5,9.5L3,11.09c-0.78,0.36-0.78,1.47,0,1.82l3.5,1.59L8.09,18c0.36,0.78,1.47,0.78,1.82,0l1.59-3.5 l3.5-1.59c0.78-0.36,0.78-1.47,0-1.82L11.5,9.5z M18.54,16l-0.79,1.75L16,18.54c-0.39,0.18-0.39,0.73,0,0.91l1.75,0.79L18.54,22 c0.18,0.39,0.73,0.39,0.91,0l0.79-1.75L22,19.46c0.39-0.18,0.39-0.73,0-0.91l-1.75-0.79L19.46,16 C19.28,15.61,18.72,15.61,18.54,16z"/></svg>`

/**
 * Builds a new <span> containing the matrix icon SVG, sized to 1.5em and vertically
 * centred so it lines up with the surrounding label text.
 *
 * @returns {HTMLSpanElement} a fresh element on every call
 */
function getIcon()
{
	let wrapper = document.createElement("span")
	wrapper.insertAdjacentHTML("beforeend", icon)
	let {style} = wrapper.querySelector("svg")
	style.setProperty("height", "1.5em")
	style.setProperty("vertical-align", "middle")
	return wrapper
}