ref: e916cf426dec2506baf74e4181c75655c4c2d9f6
parent: 1fbc5fdd4ee06c48e95afb2046b5645df61545be
author: Jan Buethe <jbuethe@amazon.de>
date: Tue Aug 1 06:35:29 EDT 2023
Added .copy() to the weight arrays exported in wexchange, so they no longer share memory with the torch parameters
--- a/dnn/torch/weight-exchange/wexchange/torch/torch.py
+++ b/dnn/torch/weight-exchange/wexchange/torch/torch.py
@@ -39,14 +39,14 @@
assert gru.num_layers == 1
assert gru.bidirectional == False
- w_ih = gru.weight_ih_l0.detach().cpu().numpy()
- w_hh = gru.weight_hh_l0.detach().cpu().numpy()
+ w_ih = gru.weight_ih_l0.detach().cpu().numpy().copy()
+ w_hh = gru.weight_hh_l0.detach().cpu().numpy().copy()
if hasattr(gru, 'bias_ih_l0'):
- b_ih = gru.bias_ih_l0.detach().cpu().numpy()
+ b_ih = gru.bias_ih_l0.detach().cpu().numpy().copy()
else:
b_ih = None
if hasattr(gru, 'bias_hh_l0'):
- b_hh = gru.bias_hh_l0.detach().cpu().numpy()
+ b_hh = gru.bias_hh_l0.detach().cpu().numpy().copy()
else:
b_hh = None
@@ -81,11 +81,11 @@
def dump_torch_dense_weights(where, dense, name='dense', scale=1/128, sparse=False, diagonal=False, quantize=False):
- w = dense.weight.detach().cpu().numpy()
+ w = dense.weight.detach().cpu().numpy().copy()
if dense.bias is None:
b = np.zeros(dense.out_features, dtype=w.dtype)
else:
- b = dense.bias.detach().cpu().numpy()
+ b = dense.bias.detach().cpu().numpy().copy()
if isinstance(where, CWriter):
return print_dense_layer(where, name, w, b, scale=scale, format='torch', sparse=sparse, diagonal=diagonal, quantize=quantize)
@@ -110,11 +110,11 @@
def dump_torch_conv1d_weights(where, conv, name='conv', scale=1/128, quantize=False):
- w = conv.weight.detach().cpu().numpy()
+ w = conv.weight.detach().cpu().numpy().copy()
if conv.bias is None:
b = np.zeros(conv.out_channels, dtype=w.dtype)
else:
- b = conv.bias.detach().cpu().numpy()
+ b = conv.bias.detach().cpu().numpy().copy()
if isinstance(where, CWriter):
@@ -141,7 +141,7 @@
def dump_torch_embedding_weights(where, emb):
os.makedirs(where, exist_ok=True)
- w = emb.weight.detach().cpu().numpy()
+ w = emb.weight.detach().cpu().numpy().copy()
np.save(os.path.join(where, 'weight.npy'), w)
--
⑨