pycoral.learn.imprinting

pycoral.learn.imprinting.engine

A weight imprinting engine that performs low-shot transfer learning for image classification models.

For more information about how to use this API and how to create the required type of model, see the guide "Retrain a classification model on-device with weight imprinting".
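The general weight-imprinting technique can be sketched in a few lines of NumPy. This is a conceptual illustration, not pycoral's implementation, and the engine's internals may differ. The sketch assumes that each class weight is the L2-normalized sum (which points the same way as the normalized average) of the L2-normalized embeddings trained for that class. It also assumes that a class's score is the dot product between the normalized query embedding and that weight. `ToyImprinter` and `l2_normalize` are hypothetical names. The `class_id` check mirrors the rule documented for `train()` but ignores pre-trained classes.

```python
import numpy as np


def l2_normalize(v):
    """Scale a vector (or each row of a matrix) to unit L2 norm."""
    v = np.asarray(v, dtype=np.float64)
    norm = np.linalg.norm(v, axis=-1, keepdims=True)
    return v / np.maximum(norm, 1e-12)  # guard against division by zero


class ToyImprinter:
    """Conceptual weight-imprinting sketch; NOT pycoral's implementation."""

    def __init__(self, embedding_dim):
        self.embedding_dim = embedding_dim
        self._sums = []  # per-class running sum of normalized embeddings

    @property
    def num_classes(self):
        return len(self._sums)

    def train(self, embedding, class_id):
        e = l2_normalize(embedding)
        if e.shape != (self.embedding_dim,):
            raise ValueError(f"expected shape ({self.embedding_dim},), got {e.shape}")
        if class_id == self.num_classes:          # add a new class
            self._sums.append(e)
        elif 0 <= class_id < self.num_classes:    # add an example to an existing class
            self._sums[class_id] = self._sums[class_id] + e
        else:
            raise ValueError(f"class_id must be in [0, {self.num_classes}]")

    def weights(self):
        # Normalizing the sum gives the same unit vector as normalizing the mean.
        return l2_normalize(np.stack(self._sums))

    def classify(self, embedding):
        """Return one cosine-similarity score per class."""
        return self.weights() @ l2_normalize(embedding)
```

For example, suppose `[1, 0, 0]` and `[0.9, 0.1, 0]` are trained under class 0 and `[0, 1, 0]` under class 1. A query such as `[0.8, 0.2, 0]` then scores highest for class 0. Only a handful of examples per class are needed, which is what makes imprinting suitable for low-shot learning.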

class pycoral.learn.imprinting.engine.ImprintingEngine(model_path, keep_classes=False)

Performs weight imprinting (transfer learning) with the given model.

Parameters
  • model_path (str) – Path to the .tflite model you want to retrain. This must be a model specially designed for this API. You can use our weight imprinting model, which has a pre-trained base model, or you can train the base model yourself by following our guide "Retrain the base MobileNet model".

  • keep_classes (bool) – If True, keep the existing classes from the pre-trained model and add new classes through training. If False, drop the existing classes so that the retrained model includes only the new classes.
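Taken together with the `class_id` rule documented for `train()`, `keep_classes` determines which label IDs `train()` will accept. The hypothetical helper below is not part of pycoral and only encodes that reading of the documentation. `num_pretrained` and `num_imprinted` are illustrative inputs, not engine attributes.

```python
def trainable_class_ids(num_pretrained, num_imprinted, keep_classes):
    """Return the class_id values train() should accept (hypothetical helper).

    num_pretrained: classes in the pre-trained model (illustrative input).
    num_imprinted:  classes added so far through imprinting (illustrative input).
    """
    # keep_classes=True keeps IDs 0..num_pretrained-1, which can't be retrained;
    # keep_classes=False drops them, so imprinted IDs start at 0.
    first_imprinted = num_pretrained if keep_classes else 0
    # Imprinted classes are retrainable, and the ID right after them
    # (the current number of classes) adds a new class.
    return range(first_imprinted, first_imprinted + num_imprinted + 1)
```

For instance, with an arbitrary 10 pre-trained classes, `keep_classes=True`, and 2 imprinted classes, the accepted IDs would be 10, 11 (retrain), and 12 (new class). With `keep_classes=False`, a fresh engine would accept only ID 0.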

property embedding_dim

Returns the number of embedding dimensions.
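Before calling `train()`, it can help to check an embedding against `embedding_dim`. The sketch below assumes that `train()` expects a flat vector with exactly `embedding_dim` entries. `as_training_embedding` is a hypothetical convenience function, not a pycoral API.

```python
import numpy as np


def as_training_embedding(engine, embedding):
    """Flatten an embedding and check its length against engine.embedding_dim.

    Hypothetical helper: assumes train() wants a 1-D vector of
    embedding_dim values. The dtype is left unchanged.
    """
    vec = np.asarray(embedding).ravel()  # e.g. (1, N) -> (N,)
    if vec.size != engine.embedding_dim:
        raise ValueError(
            f"expected {engine.embedding_dim} values, got {vec.size}")
    return vec
```

A mismatch here usually means the embedding came from a different model than the one returned by `serialize_extractor_model()`.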

property num_classes

Returns the number of classes currently trained in the model.

serialize_extractor_model()

Returns the embedding extractor model as a bytes object.

serialize_model()

Returns the newly trained model as a bytes object.
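Because both serialize methods return bytes, their results can be written straight to `.tflite` files in binary mode. The function below is a minimal sketch. `save_imprinted_models` and its parameters are hypothetical names, and only the two documented serialize methods are called on the engine.

```python
from pathlib import Path


def save_imprinted_models(engine, model_path, extractor_path=None):
    """Write the retrained classifier (and optionally the extractor) to disk.

    serialize_model() and serialize_extractor_model() both return bytes,
    so each result can be written directly in binary mode.
    """
    Path(model_path).write_bytes(engine.serialize_model())
    if extractor_path is not None:
        Path(extractor_path).write_bytes(engine.serialize_extractor_model())
```

Writing to disk is optional. As far as we know, `make_interpreter()` in `pycoral.utils.edgetpu` accepts either a file path or the model content as bytes. The bytes could therefore be loaded into an interpreter without first saving them to a file.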

train(embedding, class_id)

Trains the model with the given embedding for the specified class.

You can use this to add new classes to the model or retrain classes that you previously added using this imprinting API.

Parameters
  • embedding (numpy.array) – The embedding vector used to train the single class given by class_id.

  • class_id (int) – The label ID for this class. It must be either the current number of classes (to add a new class to the model) or the ID of an existing class that was trained using this imprinting API (you can't retrain classes from the pre-trained model).
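A typical training loop reads `num_classes` once to find the first unused ID and gives each new label the next consecutive ID. It then passes that same ID for every example of the label, so the first call adds the class and later calls retrain it. The sketch below follows that pattern, assuming only the documented `num_classes` and `train()` members. `imprint_dataset`, `embed`, and `dataset` are hypothetical. `embed` stands for whatever runs the extractor model on one example. In pycoral's own imprinting example, that step sets the input tensor, invokes the interpreter, and reads the scores with `classify.get_scores()`.

```python
def imprint_dataset(engine, embed, dataset):
    """Imprint one new class per non-empty dataset entry (sketch).

    engine:  an ImprintingEngine; only num_classes and train() are used.
    embed:   hypothetical callable mapping one example to its embedding.
    dataset: mapping of label -> list of examples, in insertion order.

    Returns a dict mapping each trained label to its class_id.
    """
    next_id = engine.num_classes  # first ID that adds a new class
    label_to_id = {}
    for label, examples in dataset.items():
        examples = list(examples)
        if len(examples) == 0:
            continue  # skipping empty labels keeps new IDs contiguous
        for example in examples:
            # The first call with next_id adds the class; later calls
            # retrain that same imprinted class.
            engine.train(embed(example), class_id=next_id)
        label_to_id[label] = next_id
        next_id += 1
    return label_to_id
```

Skipping empty labels matters because a label with no examples would otherwise consume an ID. The next label's first `train()` call would then use an ID larger than the current number of classes, which the rule above does not allow.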

API version 2.0