This repository has been archived by the owner on Apr 27, 2023. It is now read-only.

Impossible to do classification training (ValueError) #276

Open
LuciusV opened this issue Aug 15, 2021 · 2 comments

Comments

@LuciusV

LuciusV commented Aug 15, 2021

Hello! I want to train a MEGNet model to classify a set of structures (whether some property is zero or not), so I prepared the training targets as
a column of string values 'zero' or 'nonzero'.

Then the model.train method (the model is a MEGNetModel loaded from band_classification.hdf5) fails with
ValueError: Failed to convert a NumPy array to a Tensor (Unsupported object type list).
Using bool values instead of strings doesn't work either.
If I use 0 and 1 as targets, the trained model gives me float values (near 0 or 1 if trained well enough), but this is not reliable; I would like to do proper classification, probably with more than two classes.
Could you please give any advice on how to change the model properties to allow that?
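
A minimal sketch of one possible workaround, assuming the ValueError comes from non-numeric targets reaching TensorFlow (this thread does not confirm that this alone fixes it): encode each string label as an integer before training, with an explicit class order so that 'zero' maps to 0. `encode_labels` is a hypothetical helper, not part of megnet; the resulting integer list would then be passed as the targets to `model.train`.

```python
from typing import Dict, List, Optional, Sequence, Tuple


def encode_labels(
    labels: Sequence[str], classes: Optional[Sequence[str]] = None
) -> Tuple[List[int], Dict[str, int]]:
    """Map string class labels to integer indices (hypothetical helper).

    If `classes` is given, its order defines the indices; otherwise the
    distinct labels are sorted alphabetically. An unknown label raises KeyError.
    """
    if classes is None:
        classes = sorted(set(labels))
    mapping = {name: index for index, name in enumerate(classes)}
    return [mapping[name] for name in labels], mapping


targets, mapping = encode_labels(["zero", "nonzero", "zero"], classes=["zero", "nonzero"])
print(targets)  # [0, 1, 0]
print(mapping)  # {'zero': 0, 'nonzero': 1}
```

Passing `classes` explicitly fixes the label order; without it, alphabetical sorting would put 'nonzero' before 'zero'.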

Here is the function I use to build the model:

from megnet.models import MEGNetModel


def gnn_model(n_targets=1):
    # Load the pretrained model only to reuse its atom embedding weights
    model_form = MEGNetModel.from_file('band_classification.hdf5')
    embedding_layer = [i for i in model_form.layers if i.name.startswith('embedding')][0]
    embedding = embedding_layer.get_weights()[0]
    # print('Embedding matrix dimension is ', embedding.shape)

    # New model: nfeat_edge=100, nfeat_global=2, n_targets outputs
    model = MEGNetModel(100, 2, ntarget=n_targets)

    # Find the index of the embedding layer among all the model layers
    embedding_layer_index = [i for i, j in enumerate(model.layers) if j.name.startswith('atom_embedding')][0]

    # Set the weights to our previous embedding
    model.layers[embedding_layer_index].set_weights([embedding])

    # Freeze the weights
    model.layers[embedding_layer_index].trainable = False
    return model
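
The float outputs still have to be turned into class labels. A common approach, sketched below in plain Python under the assumption that each output is a per-class score: for two classes, threshold a single output at 0.5; for more classes, train on one-hot targets (presumably with the `n_targets` argument of `gnn_model` set to the number of classes) and take the argmax of each output row. The helper names and the 0.5 threshold are illustrative assumptions. Choosing a suitable final activation and loss is a separate question; megnet's `MEGNetModel` also has an `is_classification` constructor argument, and its documentation for your installed version describes what it changes.

```python
from typing import List, Sequence


def threshold_predictions(scores: Sequence[float], threshold: float = 0.5) -> List[int]:
    """Binary case: class 1 if the score reaches the threshold, else class 0."""
    return [1 if s >= threshold else 0 for s in scores]


def one_hot(indices: Sequence[int], n_classes: int) -> List[List[float]]:
    """Multi-class case: turn integer labels into one-hot target rows."""
    rows = []
    for idx in indices:
        if not 0 <= idx < n_classes:
            raise ValueError(f"label {idx} out of range for {n_classes} classes")
        row = [0.0] * n_classes
        row[idx] = 1.0
        rows.append(row)
    return rows


def argmax_predictions(outputs: Sequence[Sequence[float]]) -> List[int]:
    """Multi-class case: index of the largest output in each row (first one on ties)."""
    return [max(range(len(row)), key=row.__getitem__) for row in outputs]


print(threshold_predictions([0.03, 0.97, 0.51]))  # [0, 1, 1]
print(one_hot([0, 2], n_classes=3))  # [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
print(argmax_predictions([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]]))  # [1, 0]
```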

LuciusV closed this as completed Aug 15, 2021
LuciusV reopened this Aug 15, 2021
@chc273
Contributor

chc273 commented Aug 30, 2021

Did you solve the problem? @LuciusV

Also, make sure your numpy version is 1.19, since numpy 1.20 versions have some incompatibility issues.
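
A minimal sketch for checking the installed numpy against this advice, assuming Python 3.8+ for the standard-library `importlib.metadata`; the helper names are hypothetical. In a fresh virtual environment, the 1.19 series could be pinned with `pip install "numpy==1.19.*"`.

```python
import importlib.metadata
from typing import Tuple


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from a version string such as '1.19.5'."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def is_numpy_1_19(version: str) -> bool:
    """True if the version string belongs to the 1.19.x series."""
    return parse_major_minor(version) == (1, 19)


if __name__ == "__main__":
    try:
        installed = importlib.metadata.version("numpy")
    except importlib.metadata.PackageNotFoundError:
        print("numpy is not installed")
    else:
        status = "OK" if is_numpy_1_19(installed) else "not 1.19.x"
        print(f"numpy {installed}: {status}")
```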

@LuciusV
Author

LuciusV commented Aug 31, 2021

Thank you for pointing this out. I was using a newer numpy, so I will create a new virtual environment with numpy 1.19 and try again.
