From c0c2da78c50bdeb01394af12d75e52bfba7c5a52 Mon Sep 17 00:00:00 2001 From: Caj Larsson Date: Sat, 12 Aug 2023 16:41:03 +0200 Subject: [PATCH] Jit fails me --- jit-train-makemore.py | 91 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 jit-train-makemore.py diff --git a/jit-train-makemore.py b/jit-train-makemore.py new file mode 100644 index 0000000..72f1172 --- /dev/null +++ b/jit-train-makemore.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 + +from tinygrad.tensor import Tensor +from tinygrad.helpers import dtypes +from tinygrad.nn import Linear, Embedding +from tinygrad.nn.optim import SGD +from tinygrad.jit import TinyJit +from extra.training import sparse_categorical_crossentropy +import random +import numpy as np +import numpy.random as npr + +words = ['lorem', 'ipsum', 'dolor', 'sit', 'amet', 'consectetur', 'adipiscing', + 'elit', 'sed', 'do', 'eiusmod', 'tempor', 'incididunt', 'ut', 'labore', 'et', + 'dolore', 'magna', 'aliqua', 'ut', 'enim', 'ad', 'minim', 'veniam', 'quis', + 'nostrud', 'exercitation', 'ullamco', 'laboris', 'nisi', 'ut', 'aliquip', 'ex', + 'ea', 'commodo', 'consequat', 'duis', 'aute', 'irure', 'dolor', 'in', + 'reprehenderit', 'in', 'voluptate', 'velit', 'esse', 'cillum', 'dolore', 'eu', + 'fugiat', 'nulla', 'pariatur', 'excepteur', 'sint', 'occaecat', 'cupidatat', + 'non', 'proident', 'sunt', 'in', 'culpa', 'qui', 'officia', 'deserunt', + 'mollit', 'anim', 'id', 'est', 'laborum'] + +def atoi(char): + if char == '.': + return 0 + assert char >= "a" and char <= "z" + return ord(char) - ord("a") + 1 + +def itoa(char): + charint = int(char) + if charint == 0: + return "." + assert char >= 1 and char <= 26 + return chr(charint -1 + ord('a')) + +def dataset(words, block_size=3): + X, Y = [], [] + for word in words: + window = [0] * block_size # sliding window context + for ch in word + ".": + ix = atoi(ch) + X.append(window) + Y.append(ix) + window = window[1:] + [ix] + return np.array(X, dtype=np.float32), np.array(Y, dtype=np.float32) + +class Model: + def __init__(self, emb_size=10, hidden_n=100, vocab_size=27): + self.emb = Embedding(vocab_size, emb_size) + self.l1 = Linear(emb_size * block_size, hidden_n) + self.l2 = Linear(hidden_n, vocab_size) + + def __call__(self, x, training=True): + if training: + for p in self.parameters(): + p.requires_grad = True + emb = self.emb(x) + h = self.l1(emb.reshape(emb.shape[0], -1)).tanh() + logits = self.l2(h) + return logits + + def parameters(self): + return [self.l1.weight, self.l1.bias, self.l2.weight, self.l2.bias, self.emb.weight] + +@TinyJit +def jitloss(x,y): + return sparse_categorical_crossentropy(x, y).realize() + +block_size=3 +random.shuffle(words) +X,Y = dataset(words, block_size) +m = Model() +opt = SGD(m.parameters(), lr=0.1) + +for step in range(1000): + batch_ix = npr.randint(0, X.shape[0], (32,)) + x_batch, y_batch = Tensor(X[batch_ix], requires_grad=False), Y[batch_ix] + + logits = m(x_batch) + loss = jitloss(logits, y_batch) + + opt.zero_grad() + + loss.backward() + if m.parameters()[0].grad is None: + break + + opt.step() + + if step % 100 == 0: + print(f"Step {step+1} | Loss: {loss.numpy()}")