shithub: opus

Download patch

ref: c76756e18a8a04bdcbbb4462770f423587725b26
parent: 8bdbbfa18d4697be7ba9fc47176809d669d33f77
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Sat Jul 17 22:24:21 EDT 2021

Adding sparse training for GRU B inputs

--- a/dnn/training_tf2/lpcnet.py
+++ b/dnn/training_tf2/lpcnet.py
@@ -116,6 +116,57 @@
                 #print(thresh, np.mean(mask))
             w[1] = p
             layer.set_weights(w)
+
+class SparsifyGRUB(Callback):  # Callback that gradually block-sparsifies weight array 0 of the 'gru_b' layer during training
+    def __init__(self, t_start, t_end, interval, grua_units, density):  # density: indexable, one final density per column third (k = 0..2)
+        super(SparsifyGRUB, self).__init__()
+        self.batch = 0  # running count of on_batch_end calls
+        self.t_start = t_start  # batch count at which sparsification begins
+        self.t_end = t_end  # batch count from which the final density is enforced on every batch
+        self.interval = interval  # spacing (in batches) of mask updates while t_start <= batch < t_end
+        self.final_density = density
+        self.grua_units = grua_units  # number of leading rows that get masked (presumably the rows fed by GRU A -- TODO confirm)
+
+    def on_batch_end(self, batch, logs=None):
+        # The `batch` argument is not used; scheduling relies on the self.batch counter kept by this object.
+        self.batch += 1
+        if self.batch < self.t_start or ((self.batch-self.t_start) % self.interval != 0 and self.batch < self.t_end):
+            # Before t_start, or off-interval while still before t_end: leave the weights untouched.
+            pass
+        else:
+            # On-interval batches in [t_start, t_end) and every batch from t_end on: re-apply the sparsity mask.
+            layer = self.model.get_layer('gru_b')
+            w = layer.get_weights()
+            p = w[0]  # presumably the input kernel (Keras GRU weight order) -- TODO confirm
+            N = p.shape[0]
+            M = p.shape[1]//3  # columns are processed as three equal-width slices (presumably the three GRU gates)
+            for k in range(3):
+                density = self.final_density[k]
+                if self.batch < self.t_end:
+                    r = 1 - (self.batch-self.t_start)/(self.t_end - self.t_start)  # r falls linearly from 1 at t_start to 0 at t_end
+                    density = 1 - (1-self.final_density[k])*(1 - r*r*r)  # cubic ramp: density 1 at t_start, final_density[k] at t_end
+                A = p[:, k*M:(k+1)*M]
+                # The reshape-then-transpose below is needed because of CuDNNGRU's unusual weight ordering
+                A = np.reshape(A, (M, N))
+                A = np.transpose(A, (1, 0))  # A is (N, M) again, with elements reordered
+                N2 = self.grua_units
+                A2 = A[:N2, :]  # only the first N2 rows are sparsified
+                L=np.reshape(A2, (N2//4, 4, M//8, 8))  # view as 4x8 blocks; the reshape requires N2 % 4 == 0 and M % 8 == 0
+                S=np.sum(L*L, axis=-1)
+                S=np.sum(S, axis=1)  # S[i, j] = sum of squared weights in block (i, j), shape (N2//4, M//8)
+                SS=np.sort(np.reshape(S, (-1,)))  # block energies in ascending order
+                thresh = SS[round(M*N2//32*(1-density))]  # NOTE(review): density == 0 gives index len(SS) -> IndexError; confirm callers never pass 0
+                mask = (S>=thresh).astype('float32');  # 1.0 for blocks whose energy reaches the threshold, 0.0 otherwise
+                mask = np.repeat(mask, 4, axis=0)
+                mask = np.repeat(mask, 8, axis=1)  # block mask expanded back to element resolution, shape (N2, M)
+                A = np.concatenate([A2*mask, A[N2:,:]], axis=0)  # rows from N2 onward are kept unmasked
+                # Undo the reordering above (transpose, then reshape) to return to the CuDNNGRU weight layout
+                A = np.transpose(A, (1, 0))
+                A = np.reshape(A, (N, M))
+                p[:, k*M:(k+1)*M] = A  # write the masked slice back into the weight array
+                #print(thresh, np.mean(mask))
+            w[0] = p
+            layer.set_weights(w)
             
 
 class PCMInit(Initializer):
--