AmpliGraph

Open source Python library that predicts links between concepts in a knowledge graph.

View the GitHub repository

AmpliGraph is a suite of neural machine learning models for relational learning, a branch of machine learning that deals with supervised learning on knowledge graphs.

_images/kg_lp.png

Use AmpliGraph if you need to:

  • Discover new knowledge from an existing knowledge graph.
  • Complete large knowledge graphs with missing statements.
  • Generate stand-alone knowledge graph embeddings.
  • Develop and evaluate a new relational model.

AmpliGraph’s machine learning models generate knowledge graph embeddings, vector representations of concepts in a metric space:

_images/kg_lp_step1.png

It then combines embeddings with model-specific scoring functions to predict unseen and novel links:

_images/kg_lp_step2.png

Key Features

  • Intuitive APIs: AmpliGraph APIs are designed to reduce the amount of code required to learn models that predict links in knowledge graphs.
  • GPU-Ready: AmpliGraph is based on TensorFlow, and it is designed to run seamlessly on both CPU and GPU devices; GPUs speed up training.
  • Extensible: Roll your own knowledge graph embedding model by extending AmpliGraph base estimators.

Modules

AmpliGraph includes the following submodules:

  • Input: Helper functions to load datasets (knowledge graphs).
  • Latent Feature Models: knowledge graph embedding models. AmpliGraph contains: TransE, DistMult, ComplEx, HolE. (More to come!)
  • Evaluation: Metrics and evaluation protocols to assess the predictive power of the models.

How to Cite

If you like AmpliGraph and use it in your project, why not star the project on GitHub?

If you instead use AmpliGraph in an academic publication, cite as:

@misc{ampligraph,
  author= {Luca Costabello and
           Sumit Pai and
           Chan Le Van and
           Rory McGrath and
           Nick McCarthy},
  title = {{AmpliGraph: a Library for Representation Learning on Knowledge Graphs}},
  month = mar,
  year  = 2019,
  doi   = {10.5281/zenodo.2595043},
  url   = {https://doi.org/10.5281/zenodo.2595043}
}

Installation

Prerequisites

  • Linux Box
  • Python ≥ 3.6
Provision a Virtual Environment

Create and activate a virtual environment (conda)

conda create --name ampligraph python=3.6
source activate ampligraph
Install TensorFlow

AmpliGraph is built on TensorFlow 1.x. Install from pip or conda:

CPU-only

pip install tensorflow==1.12.0

or 

conda install tensorflow=1.12.0

GPU support

pip install tensorflow-gpu==1.12.0

or 

conda install tensorflow-gpu=1.12.0

Install AmpliGraph

Install the latest stable release from pip:

pip install ampligraph

If instead you want the most recent development version, clone the repository, check out the develop branch, and install from source (see the How to Contribute guide for details):

git clone https://github.com/Accenture/AmpliGraph.git
cd AmpliGraph
git checkout develop
pip install -e .

Sanity Check

>>> import ampligraph
>>> ampligraph.__version__
'1.0.1'

Background

Knowledge graphs are graph-based knowledge bases whose facts are modeled as relationships between entities. Knowledge graph research led to broad-scope graphs such as DBpedia [ABK+07], WordNet [Pri10], and YAGO [SKW07]. Countless domain-specific knowledge graphs have also been published on the web, giving birth to the so-called Web of Data [BHBL11].

Formally, a knowledge graph \(\mathcal{G}=\{ (sub,pred,obj)\} \subseteq \mathcal{E} \times \mathcal{R} \times \mathcal{E}\) is a set of \((sub,pred,obj)\) triples, each including a subject \(sub \in \mathcal{E}\), a predicate \(pred \in \mathcal{R}\), and an object \(obj \in \mathcal{E}\). \(\mathcal{E}\) and \(\mathcal{R}\) are the sets of all entities and relation types of \(\mathcal{G}\).
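As a minimal illustration of this definition, the sketch below stores a toy graph as a Python set of (sub, pred, obj) tuples and derives \(\mathcal{E}\) and \(\mathcal{R}\) from it. The triples and the helper name entities_and_relations are hypothetical and serve only to make the set notation concrete.

```python
# Toy knowledge graph G: a set of (sub, pred, obj) triples (hypothetical data).
G = {
    ("Acme", "basedIn", "Dublin"),
    ("Dublin", "capitalOf", "Ireland"),
    ("Alice", "worksFor", "Acme"),
}


def entities_and_relations(graph):
    """Return (E, R): entities occur as subjects or objects, relation types as predicates."""
    entities = {s for s, _, _ in graph} | {o for _, _, o in graph}
    relations = {p for _, p, _ in graph}
    return entities, relations


E, R = entities_and_relations(G)

# Every triple lies in E x R x E, as the definition requires.
assert all(s in E and p in R and o in E for s, p, o in G)
print(sorted(E))  # ['Acme', 'Alice', 'Dublin', 'Ireland']
print(sorted(R))  # ['basedIn', 'capitalOf', 'worksFor']
```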

Knowledge graph embedding models are neural architectures that encode concepts from a knowledge graph (i.e. entities \(\mathcal{E}\) and relation types \(\mathcal{R}\)) into low-dimensional, continuous vectors \(\in \mathbb{R}^k\). Such knowledge graph embeddings have applications in knowledge graph completion, entity resolution, and link-based clustering, just to cite a few [NMTG16]. Knowledge graph embeddings are learned by training a neural architecture over a graph. Although such architectures vary, the training phase always consists of minimizing a loss function \(\mathcal{L}\) that includes a scoring function \(f_{m}(t)\), i.e. a model-specific function that assigns a score to a triple \(t=(sub,pred,obj)\).

The goal of the optimization procedure is learning optimal embeddings, such that the scoring function is able to assign high scores to positive statements and low scores to statements unlikely to be true. Existing models propose scoring functions that combine the embeddings \(\mathbf{e}_{sub},\mathbf{e}_{pred}, \mathbf{e}_{obj} \in \mathbb{R}^k\) of the subject, predicate, and object of triple \(t=(sub,pred,obj)\) using different intuitions: TransE [BUGD+13] relies on distances, DistMult [YYH+14] and ComplEx [TWR+16] are bilinear-diagonal models, and HolE [NRP+16] uses circular correlation. While the above models can be interpreted as multilayer perceptrons, others such as ConvE include convolutional layers [DMSR18].

As an example, the scoring function of TransE computes a similarity between the embedding of the subject \(\mathbf{e}_{sub}\) translated by the embedding of the predicate \(\mathbf{e}_{pred}\) and the embedding of the object \(\mathbf{e}_{obj}\), using the \(L_1\) or \(L_2\) norm \(||\cdot||\):

\[f_{TransE}=-||\mathbf{e}_{sub} + \mathbf{e}_{pred} - \mathbf{e}_{obj}||_n\]
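The TransE score can be checked by hand on toy vectors. The sketch below is a plain-Python transcription of the formula, with hypothetical 3-dimensional embeddings; it is not AmpliGraph's TensorFlow implementation. An object that sits exactly at \(\mathbf{e}_{sub} + \mathbf{e}_{pred}\) gets the highest possible score, 0.

```python
import math


def transe_score(e_sub, e_pred, e_obj, norm=1):
    """f_TransE = -||e_sub + e_pred - e_obj||_n, for norm n in {1, 2}."""
    diff = [s + p - o for s, p, o in zip(e_sub, e_pred, e_obj)]
    if norm == 1:
        return -sum(abs(d) for d in diff)
    if norm == 2:
        return -math.sqrt(sum(d * d for d in diff))
    raise ValueError("norm must be 1 or 2")


# Hypothetical 3-dimensional embeddings (k = 3).
e_sub = [0.5, 0.0, 1.0]
e_pred = [0.5, 1.0, -1.0]
e_obj_good = [1.0, 1.0, 0.0]  # exactly e_sub + e_pred: the translation lands on it
e_obj_bad = [0.0, 0.0, 2.0]

print(transe_score(e_sub, e_pred, e_obj_good) == 0)    # True: the best possible score
print(transe_score(e_sub, e_pred, e_obj_bad, norm=1))  # -4.0
print(transe_score(e_sub, e_pred, e_obj_bad, norm=2))  # -2.449... (= -sqrt(6))
```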

Such a scoring function is then used on positive and negative triples \(t^+, t^-\) in the loss function. This can be, for example, a pairwise margin-based loss, as shown in the equation below:

\[\mathcal{L}(\Theta) = \sum_{t^+ \in \mathcal{G}}\sum_{t^- \in \mathcal{N}}max(0, [\gamma + f_{m}(t^-;\Theta) - f_{m}(t^+;\Theta)])\]

where \(\Theta\) are the embeddings learned by the model, \(f_{m}\) is the model-specific scoring function, \(\gamma \in \mathbb{R}\) is the margin, and \(\mathcal{N}\) is a set of negative triples generated with a corruption heuristic [BUGD+13].
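To see what the margin does, the sketch below evaluates the pairwise loss above on hypothetical scores. It implements the equation literally, summing over every (positive, negative) pair; in practice negatives are usually generated per positive triple (see the eta parameter), which this sketch does not model. Only pairs where a negative scores within \(\gamma\) of a positive contribute to the loss.

```python
def pairwise_margin_loss(pos_scores, neg_scores, gamma=1.0):
    """Sum over all (t+, t-) pairs of max(0, gamma + f(t-) - f(t+))."""
    return sum(max(0.0, gamma + f_neg - f_pos)
               for f_pos in pos_scores
               for f_neg in neg_scores)


pos_scores = [-0.5, -1.0]  # f_m(t+) for two positive triples (hypothetical)
neg_scores = [-3.0, -1.2]  # f_m(t-) for two corruptions (hypothetical)

# Only the pairs (-0.5, -1.2) and (-1.0, -1.2) violate the margin: 0.3 + 0.8.
print(pairwise_margin_loss(pos_scores, neg_scores, gamma=1.0))  # approximately 1.1
# With gamma = 0 every positive already outscores every negative, so the loss vanishes.
print(pairwise_margin_loss(pos_scores, neg_scores, gamma=0.0))  # 0.0
```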

API

AmpliGraph includes the following submodules:

Datasets

Helper functions to load knowledge graphs.

Note

It is recommended to set the AMPLIGRAPH_DATA_HOME environment variable:

export AMPLIGRAPH_DATA_HOME=/YOUR/PATH/TO/datasets

When attempting to load a dataset, the module will first check if AMPLIGRAPH_DATA_HOME is set. If it is, it will search this location for the required dataset. If the dataset is not found it will be downloaded and placed in this directory.

If AMPLIGRAPH_DATA_HOME has not been set, the datasets will be saved in the following directory:

~/ampligraph_datasets
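The lookup order described above can be sketched as follows. resolve_data_home and dataset_path are hypothetical helpers that only mirror the documented behaviour (environment variable first, then ~/ampligraph_datasets); they are not AmpliGraph functions, the per-dataset subdirectory name is an assumption, and the download step is omitted.

```python
import os


def resolve_data_home(environ=None):
    """Return AMPLIGRAPH_DATA_HOME if set, else the documented fallback ~/ampligraph_datasets."""
    environ = os.environ if environ is None else environ
    data_home = environ.get("AMPLIGRAPH_DATA_HOME")
    if data_home:
        return os.path.expanduser(data_home)
    return os.path.join(os.path.expanduser("~"), "ampligraph_datasets")


def dataset_path(name, environ=None):
    """Hypothetical directory where a dataset would be searched for (and downloaded to)."""
    return os.path.join(resolve_data_home(environ), name)


print(dataset_path("wn18rr", {"AMPLIGRAPH_DATA_HOME": "/data/kg"}))  # /data/kg/wn18rr on Linux
print(dataset_path("wn18rr", {}))  # <home directory>/ampligraph_datasets/wn18rr
```

Passing an explicit environ mapping keeps the sketch testable without touching the real environment.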
Dataset-Specific Loaders

Use these helper functions to load datasets used in the graph representation learning literature. The functions will automatically download the datasets if they are not present in ~/ampligraph_datasets or at the location set in AMPLIGRAPH_DATA_HOME.

load_wn18() Load the WN18 dataset
load_fb15k() Load the FB15k dataset
load_fb15k_237([clean_unseen]) Load the FB15k-237 dataset
load_yago3_10() Load the YAGO3-10 dataset
load_wn18rr([clean_unseen]) Load the WN18RR dataset

Datasets Summary

Dataset Train Valid Test Entities Relations
FB15K-237 272,115 17,535 20,466 14,541 237
WN18RR 86,835 3,034 3,134 40,943 11
FB15K 483,142 50,000 59,071 14,951 1,345
WN18 141,442 5,000 5,000 40,943 18
YAGO3-10 1,079,040 5,000 5,000 123,182 37

Hint

WN18 and FB15k include a large number of inverse relations, and their use in experiments has been deprecated. Use WN18RR and FB15K-237 instead.

Warning

FB15K-237’s validation set contains 8 unseen entities over 9 triples. The test set has 29 unseen entities, distributed over 28 triples. WN18RR’s validation set contains 198 unseen entities over 210 triples. The test set has 209 unseen entities, distributed over 210 triples.
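Counts like the ones above come from comparing each split against the entities seen in training. The sketch below shows the idea on toy data; filter_unseen is a hypothetical helper, not the code behind the clean_unseen option of the loaders. It also shows why the number of unseen entities and the number of affected triples can differ: one triple may mention two unseen entities.

```python
def seen_entities(train):
    """Entities that occur as subject or object in the training triples."""
    return {s for s, _, _ in train} | {o for _, _, o in train}


def filter_unseen(triples, train):
    """Split triples into (kept, dropped); dropped ones mention an entity absent from train."""
    seen = seen_entities(train)
    kept, dropped = [], []
    for s, p, o in triples:
        (kept if s in seen and o in seen else dropped).append((s, p, o))
    return kept, dropped


train = [("a", "y", "b"), ("b", "y", "c")]
test = [("a", "y", "c"), ("a", "y", "z"), ("x", "y", "z")]
kept, dropped = filter_unseen(test, train)
unseen = {e for t in dropped for e in (t[0], t[2])} - seen_entities(train)
print(kept)            # [('a', 'y', 'c')]
print(len(dropped))    # 2
print(sorted(unseen))  # ['x', 'z']
```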

Generic Loaders

Functions to load custom knowledge graphs from disk.

load_from_csv(directory_path, file_name[, …]) Load a knowledge graph from a csv file
load_from_ntriples(folder_name, file_name[, …]) Load RDF ntriples
load_from_rdf(folder_name, file_name[, …]) Load an RDF file
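For intuition about what a CSV loader returns, here is a stdlib-only sketch that parses tab-separated subject, predicate, object rows into a list of triples. read_triples is a hypothetical helper; the tab separator and the absence of a header row are assumptions, and load_from_csv may use different defaults (check its docstring).

```python
import csv
import io


def read_triples(handle, sep="\t"):
    """Read (sub, pred, obj) rows from a file-like object, skipping blank lines."""
    triples = []
    for row in csv.reader(handle, delimiter=sep):
        if not row:
            continue  # csv.reader yields [] for blank lines
        if len(row) != 3:
            raise ValueError("expected 3 columns, got %d: %r" % (len(row), row))
        triples.append(tuple(field.strip() for field in row))
    return triples


# In-memory stand-in for a file on disk (hypothetical content).
data = io.StringIO("a\ty\tb\nb\ty\tc\n\nc\ty\td\n")
X = read_triples(data)
print(X)  # [('a', 'y', 'b'), ('b', 'y', 'c'), ('c', 'y', 'd')]
```

With a real file, open it with open(path, newline="") and pass the handle to read_triples.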

Models

This module includes neural graph embedding models and support functions.

Knowledge graph embedding models are neural architectures that encode concepts from a knowledge graph (i.e. entities \(\mathcal{E}\) and relation types \(\mathcal{R}\)) into low-dimensional, continuous vectors \(\in \mathbb{R}^k\). Such knowledge graph embeddings have applications in knowledge graph completion, entity resolution, and link-based clustering, just to cite a few [NMTG16].

Knowledge Graph Embedding Models
RandomBaseline([seed]) Random baseline
TransE([k, eta, epochs, batches_count, …]) Translating Embeddings (TransE)
DistMult([k, eta, epochs, batches_count, …]) The DistMult model
ComplEx([k, eta, epochs, batches_count, …]) Complex embeddings (ComplEx)
HolE([k, eta, epochs, batches_count, seed, …]) Holographic Embeddings
Anatomy of a Model

Knowledge graph embeddings are learned by training a neural architecture over a graph. Although such architectures vary, the training phase always consists of minimizing a loss function \(\mathcal{L}\) that includes a scoring function \(f_{m}(t)\), i.e. a model-specific function that assigns a score to a triple \(t=(sub,pred,obj)\).

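AmpliGraph models are assembled from the following components, described in the subsections below: scoring functions, loss functions, regularizers, and optimizers.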

AmpliGraph comes with a number of such components. They can be used in any combination to come up with a model that performs sufficiently well for the dataset of choice.

AmpliGraph features a number of abstract classes that can be extended to design new models:

EmbeddingModel([k, eta, epochs, …]) Abstract class for embedding models
Loss(eta, hyperparam_dict[, verbose]) Abstract class for loss function.
Regularizer(hyperparam_dict[, verbose]) Abstract class for Regularizer.
Scoring functions

Existing models propose scoring functions that combine the embeddings \(\mathbf{e}_{s},\mathbf{r}_{p}, \mathbf{e}_{o} \in \mathbb{R}^k\) of the subject, predicate, and object of a triple \(t=(s,p,o)\) according to different intuitions:

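  • TransE [BUGD+13] relies on distances. The scoring function computes a similarity between the embedding of the subject translated by the embedding of the predicate and the embedding of the object, using the \(L_1\) or \(L_2\) norm \(||\cdot||\):
\[f_{TransE}=-||\mathbf{e}_{s} + \mathbf{r}_{p} - \mathbf{e}_{o}||_n\]
  • DistMult [YYH+14] is a bilinear-diagonal model that scores a triple with the trilinear product of the three embeddings:
\[f_{DistMult}=\langle \mathbf{r}_p, \mathbf{e}_s, \mathbf{e}_o \rangle\]
  • ComplEx [TWR+16] extends DistMult to complex-valued embeddings, taking the real part of the trilinear product with the conjugated object embedding:
\[f_{ComplEx}=Re(\langle \mathbf{r}_p, \mathbf{e}_s, \overline{\mathbf{e}_o} \rangle)\]
  • HolE [NRP+16] uses the circular correlation \(\star\) of subject and object embeddings, which can be computed in the Fourier domain \(\mathcal{F}\):
\[f_{HolE}=\mathbf{w}_r \cdot (\mathbf{e}_s \star \mathbf{e}_o) = \frac{1}{k}\mathcal{F}(\mathbf{w}_r)\cdot( \overline{\mathcal{F}(\mathbf{e}_s)} \odot \mathcal{F}(\mathbf{e}_o))\]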

Other models, such as ConvE, include convolutional layers [DMSR18]; these will be available in future AmpliGraph releases.
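The DistMult, ComplEx, and HolE scores above can be evaluated on toy vectors in plain Python. This is an illustrative transcription, not AmpliGraph's TensorFlow code; the embeddings are hypothetical. Circular correlation follows [NRP+16], \((\mathbf{a} \star \mathbf{b})_j = \sum_i a_i b_{(i+j) \bmod k}\). For the frequency-domain HolE form, the sketch reads the product with \(\mathcal{F}(\mathbf{w}_r)\) as a Hermitian inner product (conjugating \(\mathcal{F}(\mathbf{w}_r)\)), under which it matches the direct definition; a naive \(O(k^2)\) DFT keeps the example dependency-free.

```python
import cmath


def distmult_score(r_p, e_s, e_o):
    """<r_p, e_s, e_o>: sum of element-wise products of three real vectors."""
    return sum(r * s * o for r, s, o in zip(r_p, e_s, e_o))


def complex_score(r_p, e_s, e_o):
    """Re(<r_p, e_s, conj(e_o)>) for complex-valued embeddings."""
    return sum(r * s * o.conjugate() for r, s, o in zip(r_p, e_s, e_o)).real


def circular_correlation(a, b):
    """(a star b)_j = sum_i a_i * b_{(i + j) mod k}."""
    k = len(a)
    return [sum(a[i] * b[(i + j) % k] for i in range(k)) for j in range(k)]


def dft(x):
    """Naive discrete Fourier transform: F(x)_m = sum_n x_n * exp(-2*pi*i*m*n / k)."""
    k = len(x)
    return [sum(x[n] * cmath.exp(-2j * cmath.pi * m * n / k) for n in range(k))
            for m in range(k)]


def hole_score(w_r, e_s, e_o):
    """w_r . (e_s star e_o), computed directly."""
    return sum(w * c for w, c in zip(w_r, circular_correlation(e_s, e_o)))


def hole_score_fourier(w_r, e_s, e_o):
    """(1/k) <F(w_r), conj(F(e_s)) * F(e_o)> using a conjugating (Hermitian) inner product."""
    k = len(w_r)
    fw, fs, fo = dft(w_r), dft(e_s), dft(e_o)
    total = sum(a.conjugate() * b.conjugate() * c for a, b, c in zip(fw, fs, fo))
    return (total / k).real


# Hypothetical real-valued toy embeddings (k = 3) for DistMult and HolE.
r, s, o = [1.0, 0.5, -1.0], [0.2, 1.0, 0.3], [1.0, 2.0, 0.5]
print(distmult_score(r, s, o))                           # ~1.05; swapping s and o gives the same value
print(hole_score(r, s, o), hole_score_fourier(r, s, o))  # both ~1.25

# Hypothetical complex toy embeddings (k = 2): ComplEx can score (s, p, o) and (o, p, s) differently.
rc, sc, oc = [1j, 1.0], [1.0, 1j], [1j, 1j]
print(complex_score(rc, sc, oc), complex_score(rc, oc, sc))  # 2.0 0.0
```

The last line illustrates why ComplEx extends DistMult: the trilinear product of real vectors is symmetric in subject and object, whereas conjugating the complex object embedding lets the model represent asymmetric relations.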

Loss Functions

AmpliGraph includes a number of loss functions commonly used in the literature. Each function can be used with any of the implemented models. Loss functions are passed to models as hyperparameters, and can thus be tuned during model selection.

PairwiseLoss(eta[, loss_params, verbose]) Pairwise, max-margin loss.
NLLLoss(eta[, loss_params, verbose]) Negative log-likelihood loss.
AbsoluteMarginLoss(eta[, loss_params, verbose]) Absolute margin, max-margin loss.
SelfAdversarialLoss(eta[, loss_params, verbose]) Self adversarial sampling loss.
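As a rough illustration of the NLLLoss entry, the sketch below computes a negative log-likelihood loss in the form commonly used for link prediction (for example by ComplEx [TWR+16]): positives get label +1, corruptions get label -1, and the loss is \(\sum_t \log(1 + \exp(-y_t f_m(t)))\). This textbook formulation is an assumption made for illustration; AmpliGraph's NLLLoss may differ in details such as averaging or how negatives are weighted.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def nll_loss(pos_scores, neg_scores):
    """Sum of log(1 + exp(-y * f)) with y = +1 for positives and y = -1 for negatives."""
    return (sum(softplus(-f) for f in pos_scores)
            + sum(softplus(f) for f in neg_scores))


# A confident, correct model (high f for positives, low f for negatives) gets a small loss.
print(nll_loss([5.0], [-5.0]))  # about 0.0134
print(nll_loss([-5.0], [5.0]))  # about 10.0134: scores on the wrong side are penalised
```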
Regularizers

AmpliGraph includes a number of regularizers that can be used with the loss function. LPRegularizer supports L1, L2, and L3.

LPRegularizer([regularizer_params, verbose]) Performs LP regularization
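An Lp penalty on the embeddings is commonly written as \(\lambda \sum_i |\theta_i|^p\). The sketch below computes that quantity for p in {1, 2, 3}; whether LPRegularizer adds exactly this term to the loss (for instance, raised to the p-th power rather than as a norm, or summed rather than averaged) is an assumption. The p and lam arguments mirror the 'p' and 'lambda' keys of regularizer_params used in the Examples section.

```python
def lp_penalty(embeddings, p=2, lam=1e-5):
    """lambda * sum over all embedding coordinates of |theta|^p."""
    if p not in (1, 2, 3):
        raise ValueError("this sketch only handles p in {1, 2, 3}")
    return lam * sum(abs(x) ** p for vector in embeddings for x in vector)


theta = [[1.0, -2.0], [0.5, 0.0]]  # two toy 2-d embeddings (hypothetical)
for p in (1, 2, 3):
    print(p, lp_penalty(theta, p=p, lam=1.0))
# 1 3.5
# 2 5.25
# 3 9.125
```

Larger p penalises large coordinates more heavily, which is why p is exposed as a tunable hyperparameter.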
Optimizers

The goal of the optimization procedure is learning optimal embeddings, such that the scoring function is able to assign high scores to positive statements and low scores to statements unlikely to be true.

AmpliGraph supports the SGD-based optimizers provided by TensorFlow: select one with the optimizer argument of the model initializer. Best results are currently obtained with Adam.

Utils Functions

Models can be saved and restored from disk. This is useful to avoid re-training a model.

save_model(model, loc) Save a trained model to disk.
restore_model(loc) Restore a saved model from disk.

Evaluation

The module includes performance metrics for neural graph embedding models, along with model selection routines, negatives generation, and an implementation of the learning-to-rank-based evaluation protocol used in the literature.

Metrics

Learning-to-rank metrics to evaluate the performance of neural graph embedding models.

rank_score(y_true, y_pred[, pos_lab]) Rank of a triple
mrr_score(ranks) Mean Reciprocal Rank (MRR)
mr_score(ranks) Mean Rank (MR)
hits_at_n_score(ranks, n) Hits@N
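The three rank-based metrics reduce to one-liners over a list of ranks (1 is best). The sketch below re-implements the definitions in plain Python to make them explicit; it is illustrative only, and in practice the ranks returned by evaluate_performance are passed to mr_score, mrr_score, and hits_at_n_score.

```python
def mr(ranks):
    """Mean Rank: average position of the true triple (lower is better)."""
    return sum(ranks) / len(ranks)


def mrr(ranks):
    """Mean Reciprocal Rank: average of 1/rank (higher is better, at most 1)."""
    return sum(1.0 / r for r in ranks) / len(ranks)


def hits_at_n(ranks, n):
    """Hits@N: fraction of true triples ranked in the top n."""
    return sum(1 for r in ranks if r <= n) / len(ranks)


ranks = [1, 3, 10, 20]
print(mr(ranks))             # 8.5
print(mrr(ranks))            # (1 + 1/3 + 1/10 + 1/20) / 4, about 0.3708
print(hits_at_n(ranks, 3))   # 0.5
print(hits_at_n(ranks, 10))  # 0.75
```

Note how the single rank of 20 drags MR up far more than it affects MRR; this is why MRR is usually preferred as the headline metric.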
Negatives Generation

Negatives generation routines. These are corruption strategies based on the Local Closed-World Assumption (LCWA).

generate_corruptions_for_eval(X, …[, …]) Generate corruptions for evaluation.
generate_corruptions_for_fit(X[, …]) Generate corruptions for training.
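A corruption of (s, p, o) keeps the predicate and replaces either the subject or the object with another entity. The sketch below enumerates such corruptions for one triple, mirroring the 's+o' corrupt_side setting from the Examples section (subject and object corrupted separately, never at once). corrupt and sample_corruptions are hypothetical helpers, not the TensorFlow routines listed above; excluding the original entity and sampling without replacement are choices made for this sketch. Evaluation enumerates all corruptions, while training typically samples a few per positive (the eta hyperparameter).

```python
import random


def corrupt(triple, entities, side="s+o"):
    """Corruptions of (s, p, o): swap in every other entity as subject and/or object, one side at a time."""
    s, p, o = triple
    corruptions = []
    if "s" in side:
        corruptions += [(e, p, o) for e in entities if e != s]
    if "o" in side:
        corruptions += [(s, p, e) for e in entities if e != o]
    return corruptions


def sample_corruptions(triple, entities, eta, seed=0):
    """Draw eta distinct corruptions at random, as a training-time negative sampler might."""
    rng = random.Random(seed)
    return rng.sample(corrupt(triple, entities), eta)


entities = ["a", "b", "c", "d"]
print(corrupt(("a", "y", "b"), entities))
# [('b', 'y', 'b'), ('c', 'y', 'b'), ('d', 'y', 'b'), ('a', 'y', 'a'), ('a', 'y', 'c'), ('a', 'y', 'd')]
print(len(sample_corruptions(("a", "y", "b"), entities, eta=2)))  # 2
```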
Evaluation & Model Selection

Functions to evaluate the predictive power of knowledge graph embedding models, and routines for model selection.

evaluate_performance(X, model[, …]) Evaluate the performance of an embedding model.
select_best_model_ranking(model_class, X, …) Model selection routine for embedding models.
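The filtered protocol used by evaluate_performance (the filter_triples argument in the Examples section) can be summarised as: score the test triple and its corruptions, discard corruptions that are themselves known positives, and rank the test triple among what remains. The sketch below shows that logic with a generic score function; filtered_rank, the tie-breaking rule (only strictly higher scores push the rank down), and the toy scores are assumptions for illustration.

```python
def filtered_rank(test_triple, corruptions, score, known_positives):
    """Rank of test_triple (1 = best) among corruptions that are not known positives."""
    pos_score = score(test_triple)
    candidates = [t for t in corruptions if t not in known_positives]
    # Only corruptions scored strictly higher than the test triple push its rank down.
    return 1 + sum(1 for t in candidates if score(t) > pos_score)


# Hypothetical scores produced by some trained model.
scores = {("a", "y", "b"): 0.9, ("a", "y", "c"): 0.95,
          ("a", "y", "d"): 0.2, ("a", "y", "e"): 0.1}
score = scores.__getitem__
test = ("a", "y", "b")
corruptions = [("a", "y", "c"), ("a", "y", "d"), ("a", "y", "e")]
known = {("a", "y", "b"), ("a", "y", "c")}  # ("a", "y", "c") is a true fact elsewhere in the dataset

print(filtered_rank(test, corruptions, score, known_positives=set()))  # raw rank: 2
print(filtered_rank(test, corruptions, score, known_positives=known))  # filtered rank: 1
```

Without filtering, the model is penalised for ranking another true fact above the test triple; filtering removes that artefact.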
Helper Functions

Utilities and support functions for evaluation procedures.

train_test_split_no_unseen(X[, test_size, seed]) Split into train and test sets.
create_mappings(X) Create string-IDs mappings for entities and relations.
to_idx(X, ent_to_idx, rel_to_idx) Convert statements (triples) into integer IDs.
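What create_mappings and to_idx do can be pictured with a small pure-Python equivalent: build string-to-integer dictionaries for entities and relations, then replace each label with its ID. make_mappings and triples_to_ids are hypothetical stand-ins; the real functions operate on NumPy arrays and may assign IDs in a different order or return the mappings in a different order.

```python
def make_mappings(triples):
    """Assign consecutive integer IDs to entities and relations (sorted for determinism)."""
    entities = sorted({s for s, _, _ in triples} | {o for _, _, o in triples})
    relations = sorted({p for _, p, _ in triples})
    ent_to_idx = {e: i for i, e in enumerate(entities)}
    rel_to_idx = {r: i for i, r in enumerate(relations)}
    return ent_to_idx, rel_to_idx


def triples_to_ids(triples, ent_to_idx, rel_to_idx):
    """Replace (sub, pred, obj) labels with their integer IDs."""
    return [(ent_to_idx[s], rel_to_idx[p], ent_to_idx[o]) for s, p, o in triples]


X = [("a", "y", "b"), ("b", "y", "a"), ("f", "z", "e")]
ent_to_idx, rel_to_idx = make_mappings(X)
print(ent_to_idx)  # {'a': 0, 'b': 1, 'e': 2, 'f': 3}
print(rel_to_idx)  # {'y': 0, 'z': 1}
print(triples_to_ids(X, ent_to_idx, rel_to_idx))  # [(0, 0, 1), (1, 0, 0), (3, 1, 2)]
```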

How to Contribute

Git Repo and Issue Tracking

The AmpliGraph repository is available on GitHub.

A list of open issues is available here.

The AmpliGraph Slack channel is available here.

How to Contribute

We welcome community contributions, whether they are new models, tests, or documentation.

You can contribute to AmpliGraph in many ways:

Adding Your Own Model

The landscape of knowledge graph embeddings evolves rapidly. We welcome new models as contributions to AmpliGraph, which has been built to provide a shared codebase that guarantees a fair evaluation and comparison across models.

You can add your own model by raising a pull request.

To get started, read the documentation on how current models have been implemented.

Clone and Install in editable mode

Clone the repository and check out the develop branch. Install from source with pip; use the -e flag to enable editable mode:

git clone https://github.com/Accenture/AmpliGraph.git
cd AmpliGraph
git checkout develop
pip install -e .

Unit Tests

To run all the unit tests:

$ pytest tests

See pytest documentation for additional arguments.

Documentation

The project documentation is based on Sphinx and can be built on your local working copy as follows:

cd docs
make clean autogen html

The above generates an HTML version of the documentation under docs/_built/html.

Packaging

To build an AmpliGraph custom wheel, do the following:

pip wheel --wheel-dir dist --no-deps .

Examples

Train and evaluate an embedding model

import numpy as np
from ampligraph.datasets import load_wn18
from ampligraph.latent_features import ComplEx
from ampligraph.evaluation import evaluate_performance, mrr_score, hits_at_n_score

def main():

    # load Wordnet18 dataset:
    X = load_wn18()

    # Initialize a ComplEx neural embedding model with pairwise loss function.
    # The model will be trained for 20 epochs.
    model = ComplEx(batches_count=10, seed=0, epochs=20, k=150, eta=10,
                    # Use adam optimizer with learning rate 1e-3
                    optimizer='adam', optimizer_params={'lr': 1e-3},
                    # Use pairwise loss with margin 0.5
                    loss='pairwise', loss_params={'margin': 0.5},
                    # Use L2 regularizer with regularizer weight 1e-5
                    regularizer='LP', regularizer_params={'p': 2, 'lambda': 1e-5},
                    # Enable stdout messages (set to False to suppress them)
                    verbose=True)

    # For evaluation, we can use a filter to discard corruptions that are
    # actually positive statements (i.e. they appear somewhere in the dataset).
    # Here we define the filter set by concatenating all the positives.
    # (Named X_filter to avoid shadowing Python's built-in filter.)
    X_filter = np.concatenate((X['train'], X['valid'], X['test']))

    # Fit the model on the training set; the validation set drives early stopping.
    # Note: with epochs=20 and burn_in=100, early stopping never activates;
    # increase epochs to exercise it.
    model.fit(X['train'],
              early_stopping=True,
              early_stopping_params={
                  'x_valid': X['valid'],        # validation set
                  'criteria': 'hits10',         # use Hits@10 as the early stopping criterion
                  'burn_in': 100,               # early stopping kicks in after 100 epochs
                  'check_interval': 20,         # validate every 20th epoch
                  'stop_interval': 5,           # stop if 5 successive validation checks are worse
                  'x_filter': X_filter,         # filter out positives during validation
                  'corruption_entities': 'all', # corrupt using all entities
                  'corrupt_side': 's+o'         # corrupt subject and object (but not at once)
              })

    # Run the evaluation procedure on the test set (with filtering).
    # To disable filtering: filter_triples=None
    # Usually, we corrupt subject and object sides separately and compute ranks.
    ranks = evaluate_performance(X['test'],
                                 model=model,
                                 filter_triples=X_filter,
                                 use_default_protocol=True,  # corrupt subj and obj separately while evaluating
                                 verbose=True)

    # compute and print metrics:
    mrr = mrr_score(ranks)
    hits_10 = hits_at_n_score(ranks, n=10)
    print("MRR: %f, Hits@10: %f" % (mrr, hits_10))
    # Output: MRR: 0.886406, Hits@10: 0.935000

if __name__ == "__main__":
    main()

Model selection

from ampligraph.datasets import load_wn18
from ampligraph.latent_features import ComplEx
from ampligraph.evaluation import select_best_model_ranking

def main():

    # load Wordnet18 dataset:
    X_dict = load_wn18()

    model_class = ComplEx

    # Use the template given below for doing grid search. 
    param_grid = {
                     "batches_count": [10],
                     "seed": 0,
                     "epochs": [4000],
                     "k": [100, 50],
                     "eta": [5,10],
                     "loss": ["pairwise", "nll", "self_adversarial"],
                     # We take care of mapping the params to corresponding classes
                     "loss_params": {
                         # margin applies to both the pairwise and self-adversarial losses
                         "margin": [0.5, 20],
                         # alpha applies to the self-adversarial loss
                         "alpha": [0.5]
                     },
                     "embedding_model_params": {
                         # generate corruption using all entities during training
                         "negative_corruption_entities":"all"
                     },
                     "regularizer": [None, "LP"],
                     "regularizer_params": {
                         "p": [2],
                         "lambda": [1e-4, 1e-5]
                     },
                     "optimizer": ["adam"],
                     "optimizer_params":{
                         "lr": [0.01, 0.0001]
                     },
                     "verbose": True
                 }

    # Train the model on all possible combinations of hyperparameters.
    # Models are validated on the validation set.
    # It returns a model re-trained on the training and validation sets.
    best_model, best_params, best_mrr_train, \
    ranks_test, mrr_test = select_best_model_ranking(model_class, # Class handle of the model to be used
                                                     # Dataset 
                                                     X_dict,          
                                                     # Parameter grid
                                                     param_grid,      
                                                     # Use filtered set for eval
                                                     use_filter=True, 
                                                     # corrupt subject and objects separately during eval
                                                     use_default_protocol=True, 
                                                     # Log all the model hyperparams and evaluation stats
                                                     verbose=True)
    print(type(best_model).__name__, best_params, best_mrr_train, mrr_test)

if __name__ == "__main__":
    main()

Get the embeddings

import numpy as np
from ampligraph.latent_features import ComplEx

model = ComplEx(batches_count=1, seed=555, epochs=20, k=10)
X = np.array([['a', 'y', 'b'],
              ['b', 'y', 'a'],
              ['a', 'y', 'c'],
              ['c', 'y', 'a'],
              ['a', 'y', 'd'],
              ['c', 'y', 'd'],
              ['b', 'y', 'c'],
              ['f', 'y', 'e']])
model.fit(X)
model.get_embeddings(['f','e'], type='entity')

Save and restore a model


import numpy as np

from ampligraph.latent_features import ComplEx, save_model, restore_model

model = ComplEx(batches_count=2, seed=555, epochs=20, k=10)

X = np.array([['a', 'y', 'b'],
            ['b', 'y', 'a'],
            ['a', 'y', 'c'],
            ['c', 'y', 'a'],
            ['a', 'y', 'd'],
            ['c', 'y', 'd'],
            ['b', 'y', 'c'],
            ['f', 'y', 'e']])

model.fit(X)

EXAMPLE_LOC = 'saved_models'

# Use the trained model to predict 
y_pred_before = model.predict(np.array([['f', 'y', 'e'], ['b', 'y', 'd']]))
print(y_pred_before)

# Save the model
save_model(model, EXAMPLE_LOC)

# Restore the model
restored_model = restore_model(EXAMPLE_LOC)

# Use the restored model to predict
y_pred_after = restored_model.predict(np.array([['f', 'y', 'e'], ['b', 'y', 'd']]))
print(y_pred_after)

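# Assert that the predictions before and after restoring are the same.
# predict returns an array, so compare element-wise instead of with a bare ==.
np.testing.assert_allclose(y_pred_before, y_pred_after)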

Performance

Predictive Performance

We report the filtered MR, MRR, Hits@1,3,10 for the most common datasets used in the literature.

FB15K-237

Model MR MRR Hits@1 Hits@3 Hits@10 Hyperparameters
TransE 153 0.31 0.22 0.35 0.51 batches_count: 60; embedding_model_params: norm: 1; epochs: 4000; eta: 50; k: 1000; loss: self_adversarial; loss_params: alpha: 0.5; margin: 5; optimizer: adam; optimizer_params: lr: 0.0001; seed: 0; normalize_ent_emb: false
DistMult 568 0.29 0.20 0.32 0.47 batches_count: 50; epochs: 4000; eta: 50; k: 400; loss: self_adversarial; loss_params: alpha: 1; margin: 1; optimizer: adam; optimizer_params: lr: 0.0005; regularizer: LP; regularizer_params: lambda: 0.0001; p: 2; seed: 0; normalize_ent_emb: False;
ComplEx 519 0.30 0.20 0.33 0.48 batches_count: 50; epochs: 4000; eta: 30; k: 350; loss: self_adversarial; loss_params: alpha: 1; margin: 0.5; optimizer: adam; optimizer_params: lr: 0.0001; seed: 0
HolE 297 0.28 0.19 0.31 0.46 batches_count: 50; epochs: 4000; eta: 30; k: 350; loss: self_adversarial; loss_params: alpha: 1; margin: 0.5; optimizer: adam; optimizer_params: lr: 0.0001; seed: 0

Note

FB15K-237 validation and test sets include triples with entities that do not occur in the training set. We found 8 unseen entities in the validation set and 29 in the test set. In the experiments we excluded the triples where such entities appear (9 triples from the validation set and 28 from the test set).

WN18RR

Model MR MRR Hits@1 Hits@3 Hits@10 Hyperparameters
TransE 1536 0.23 0.07 0.35 0.51 batches_count: 100; embedding_model_params: norm: 1; epochs: 4000; eta: 20; k: 200; loss: self_adversarial; loss_params: margin: 1; optimizer: adam; optimizer_params: lr: 0.0001; regularizer: LP; regularizer_params: lambda: 1.0e-05; p: 1; seed: 0; normalize_ent_emb: false
DistMult 6853 0.44 0.42 0.45 0.50 batches_count: 25; epochs: 4000; eta: 20; k: 200; loss: self_adversarial; loss_params: margin: 1; optimizer: adam; optimizer_params: lr: 0.0005; seed: 0; normalize_ent_emb: false
ComplEx 8214 0.44 0.41 0.45 0.50 batches_count: 10; epochs: 4000; eta: 20; k: 200; loss: nll; loss_params: margin: 1; optimizer: adam; optimizer_params: lr: 0.0005; seed: 0
HolE 7305 0.47 0.43 0.48 0.53 batches_count: 50; epochs: 4000; eta: 20; k: 200; loss: self_adversarial; loss_params: margin: 1; optimizer: adam; optimizer_params: lr: 0.0005; seed: 0;

Note

WN18RR validation and test sets include triples with entities that do not occur in the training set. We found 198 unseen entities in the validation set and 209 in the test set. In the experiments we excluded the triples where such entities appear (210 triples from the validation set and 210 from the test set).

FB15K

Warning

The dataset includes a large number of inverse relations, and its use in experiments has been deprecated. Use FB15k-237 instead.

Model MR MRR Hits@1 Hits@3 Hits@10 Hyperparameters
TransE 105 0.55 0.39 0.68 0.79 batches_count: 10; embedding_model_params: norm: 1; epochs: 4000; eta: 5; k: 150; loss: pairwise; loss_params: margin: 0.5; optimizer: adam; optimizer_params: lr: 0.0001; regularizer: LP; regularizer_params: lambda: 0.0001; p: 2; seed: 0; normalize_ent_emb: false
DistMult 177 0.79 0.74 0.82 0.86 batches_count: 50; epochs: 4000; eta: 20; k: 200; loss: self_adversarial; loss_params: margin: 1; optimizer: adam; optimizer_params: lr: 0.0005; seed: 0; normalize_ent_emb: false
ComplEx 188 0.79 0.76 0.82 0.86 batches_count: 100; epochs: 4000; eta: 20; k: 200; loss: self_adversarial; loss_params: margin: 1; optimizer: adam; optimizer_params: lr: 0.0005; seed: 0
HolE 212 0.80 0.76 0.83 0.87 batches_count: 50; epochs: 4000; eta: 20; k: 200; loss: self_adversarial; loss_params: margin: 1; optimizer: adam; optimizer_params: lr: 0.0005; seed: 0

WN18

Warning

The dataset includes a large number of inverse relations, and its use in experiments has been deprecated. Use WN18RR instead.

Model MR MRR Hits@1 Hits@3 Hits@10 Hyperparameters
TransE 446 0.50 0.18 0.81 0.89 batches_count: 10; embedding_model_params: norm: 1; epochs: 4000; eta: 5; k: 150; loss: pairwise; loss_params: margin: 0.5; optimizer: adam; optimizer_params: lr: 0.0001; regularizer: LP; regularizer_params: lambda: 0.0001; p: 2; seed: 0; normalize_ent_emb: false
DistMult 746 0.83 0.73 0.92 0.95 batches_count: 50; epochs: 4000; eta: 20; k: 200; loss: nll; loss_params: margin: 1; optimizer: adam; optimizer_params: lr: 0.0005; seed: 0; normalize_ent_emb: false
ComplEx 715 0.94 0.94 0.95 0.95 batches_count: 50; epochs: 4000; eta: 20; k: 200; loss: nll; loss_params: margin: 1; optimizer: adam; optimizer_params: lr: 0.0005; seed: 0
HolE 658 0.94 0.93 0.94 0.95 batches_count: 50; epochs: 4000; eta: 20; k: 200; loss: self_adversarial; loss_params: margin: 1; optimizer: adam; optimizer_params: lr: 0.0005; seed: 0

To reproduce the above results:

$ cd experiments
$ python predictive_performance.py

Note

Running predictive_performance.py on all datasets, for all models takes ~13 hours on an Intel Xeon Gold 6142, 64 GB Ubuntu 16.04 box equipped with a Tesla V100 16GB.

Experiments can be limited to specific model-dataset combinations as follows:

$ python predictive_performance.py -h
usage: predictive_performance.py [-h] [-d {fb15k,fb15k-237,wn18,wn18rr}]
                                 [-m {complex,transe,distmult,hole}]

optional arguments:
  -h, --help            show this help message and exit
  -d {fb15k,fb15k-237,wn18,wn18rr}, --dataset {fb15k,fb15k-237,wn18,wn18rr}
  -m {complex,transe,distmult,hole}, --model {complex,transe,distmult,hole}

Runtime Performance

Training the models on FB15K-237 (k=200, eta=2, batches_count=100, loss=nll), on an Intel Xeon Gold 6142, 64 GB Ubuntu 16.04 box equipped with a Tesla V100 16GB gives the following runtime report:

model seconds/epoch
ComplEx 3.19
TransE 3.26
DistMult 2.61
HolE 3.21

Bibliography

[ABK+07] Sören Auer, Christian Bizer, Georgi Kobilarov, Jens Lehmann, Richard Cyganiak, and Zachary Ives. DBpedia: a nucleus for a web of open data. In The Semantic Web, 722–735. Springer, 2007.
[BHBL11] Christian Bizer, Tom Heath, and Tim Berners-Lee. Linked data: the story so far. In Semantic Services, Interoperability and Web Applications: Emerging Concepts, 205–227. IGI Global, 2011.
[BUGD+13] Antoine Bordes, Nicolas Usunier, Alberto Garcia-Duran, Jason Weston, and Oksana Yakhnenko. Translating embeddings for modeling multi-relational data. In Advances in Neural Information Processing Systems, 2787–2795. 2013.
[DMSR18] Tim Dettmers, Pasquale Minervini, Pontus Stenetorp, and Sebastian Riedel. Convolutional 2D knowledge graph embeddings. In Procs of AAAI. 2018. URL: https://www.aaai.org/ocs/index.php/AAAI/AAAI18/paper/view/17366.
[HOSM17] Takuo Hamaguchi, Hidekazu Oiwa, Masashi Shimbo, and Yuji Matsumoto. Knowledge transfer for out-of-knowledge-base entities: a graph neural network approach. In IJCAI International Joint Conference on Artificial Intelligence, 1802–1808. 2017.
[HS17] Katsuhiko Hayashi and Masashi Shimbo. On the equivalence of holographic and complex embeddings for link prediction. CoRR, 2017. URL: http://arxiv.org/abs/1702.05563, arXiv:1702.05563.
[MBS13] Farzaneh Mahdisoltani, Joanna Biega, and Fabian M Suchanek. YAGO3: a knowledge base from multilingual Wikipedias. In CIDR. 2013.
[NMTG16] Maximilian Nickel, Kevin Murphy, Volker Tresp, and Evgeniy Gabrilovich. A review of relational machine learning for knowledge graphs. Procs of the IEEE, 104(1):11–33, 2016.
[NRP+16] Maximilian Nickel, Lorenzo Rosasco, Tomaso A Poggio, and others. Holographic embeddings of knowledge graphs. In AAAI, 1955–1961. 2016.
[Pri10] Princeton. About WordNet. Web, 2010. https://wordnet.princeton.edu.
[SDNT19] Zhiqing Sun, Zhi-Hong Deng, Jian-Yun Nie, and Jian Tang. RotatE: knowledge graph embedding by relational rotation in complex space. In International Conference on Learning Representations. 2019. URL: https://openreview.net/forum?id=HkgEQnRqYQ.
[SKW07] Fabian M Suchanek, Gjergji Kasneci, and Gerhard Weikum. YAGO: a core of semantic knowledge. In Procs of WWW, 697–706. ACM, 2007.
[TCP+15] Kristina Toutanova, Danqi Chen, Patrick Pantel, Hoifung Poon, Pallavi Choudhury, and Michael Gamon. Representing text for joint embedding of text and knowledge bases. In Proceedings of the 2015 Conference on Empirical Methods in Natural Language Processing, 1499–1509. 2015.
[TWR+16] Théo Trouillon, Johannes Welbl, Sebastian Riedel, Éric Gaussier, and Guillaume Bouchard. Complex embeddings for simple link prediction. In International Conference on Machine Learning, 2071–2080. 2016.
[YYH+14] Bishan Yang, Wen-tau Yih, Xiaodong He, Jianfeng Gao, and Li Deng. Embedding entities and relations for learning and inference in knowledge bases. arXiv preprint, 2014.

Changelog

1.0.1

  • Evaluation protocol now ranks object and subject corruptions separately
  • Corruption generation can now use entities from the current batch only
  • FB15k-237, WN18RR loaders filter out triples with unseen entities by default
  • Removed some unused arguments
  • Improved documentation
  • Minor bugfixing

1.0.0

  • TransE
  • DistMult
  • ComplEx
  • FB15k, WN18, FB15k-237, WN18RR, YAGO3-10 loaders
  • generic loader for csv files
  • RDF, ntriples loaders
  • Learning to rank evaluation protocol
  • TensorFlow-based negatives generation
  • save/restore capabilities for models
  • pairwise loss
  • nll loss
  • self-adversarial loss
  • absolute margin loss
  • Model selection routine
  • LCWA corruption strategy for training and eval
  • rank, Hits@N, MRR scores functions

About

AmpliGraph is maintained by Accenture Labs Dublin.

Contact us

The AmpliGraph Slack channel is available here.

You can contact us by email at about@ampligraph.org.

Contributors

Active contributors (in alphabetical order)