Finding Triplets with Neo4j

A user had an interesting Neo4j question on Stack Overflow the other day:

I have two types of nodes in my graph. One type is Testplan and the other is Tag. Testplans are tagged to Tags. I want most common pairs of Tags that share the same Testplans with a Tag having a specific name. I have been able to achieve the most common Tags sharing the same Testplan with one Tag, but getting confused when trying to do it for pairs of Tags.

Their cypher query looked like this:

MATCH (kw1:Tag)<-[e:TAGGED]-(tp1:Testplan)-[e2:TAGGED]->(kw2:Tag) 
WHERE kw1.name = "result"
RETURN kw1,kw2,count(tp1)
ORDER BY count(tp1) DESC

…and their results were a little off:

Kw1                   kw2                       count(tp1)
“result”              “error”                   104
“result”              “prerequisites”           89
“result”              “alpha”                   63

What they wanted was something closer to this:

Kw1                   kw2                           count(tp1)
“result”              “error”,“prerequisites”       70
“result”              “error”,“alpha”               63

The answer to their query was this:

MATCH (kw1:Tag)<-[e:TAGGED]-(tp1:Testplan)-[e2:TAGGED]->(kw2:Tag),
      (tp1)-[e3:TAGGED]->(kw3:Tag)
WHERE kw1.name = "result"
AND ID(kw2)<ID(kw3)
RETURN kw2, kw3,count(tp1)
ORDER BY count(tp1) DESC

They were trying to group kw2 into pairs, but what they needed was a second TAGGED relationship to a kw3, so that every path is guaranteed to contain the original tag, kw2 and kw3. The ID(kw2)<ID(kw3) condition keeps only one ordering of each pair, so the same two tags are not returned twice as (kw2, kw3) and again as (kw3, kw2).

This answered the question, but they ran into another problem:

in place of two if I want 3/4/5 such combinations, what should be the ideal way? I tried to tailor the query without understanding much, but this is making Neo4j crash even on a high performance computer.

Cypher is a pattern matching language and it is meant to deal with paths as the core unit of work. This works great for most things, but sometimes it means doing a lot more work than necessary. That’s not a problem. What we have to do is convert our Cypher query to a stored procedure and try to be smarter about the work we are doing.

I’m calling this procedure “combinations” and we will pass in the first tag, the number of tag combinations (2 for doubles, 3 for triples, 4 for quads, etc) and a limit to only see the top sets of tags.

    @Procedure(name = "com.maxdemarzi.combinations", mode = Mode.READ)
    @Description("CALL com.maxdemarzi.combinations(tag, number, limit) - find combinations")
    public Stream<MapResult> getCombinations(@Name("tag") String tag, @Name("size") Number k, @Name("limit") Number limit) throws IOException {

The first thing we’ll want to do is find the starting Tag node:

// We start by finding the tag
Node tagNode = db.findNode(Labels.Tag, "name", tag);

I need to keep a count of all the combinations we find, so we’ll use a Map. We have to use a dynamic object as our key because we’re not really sure how many tags we will be asked to keep track of in the set.

// We will keep track of the combinations as a String of node ids
Map<String, Integer> counts = new HashMap<>();

With that out of the way, we want to traverse from our tag to all of the Test Plans connected to our starting Tag.

// Next we find all the test plans tagged by this tag
for (Relationship tagged : tagNode.getRelationships(Direction.INCOMING, RelationshipTypes.TAGGED)) {
     Node testPlan = tagged.getStartNode();

For each Test Plan, we want to find all the tags that belong to it, and keep these in a list. Notice I’m using Relationship::getEndNodeId() to just grab the id of the node on the other end of our relationship instead of the node object.

            // Then find all the tags for this test plan
            ArrayList<Long> tags = new ArrayList<>();
            for (Relationship taggedToo : testPlan.getRelationships(Direction.OUTGOING, RelationshipTypes.TAGGED)) {
                // Only the node id is needed, not the node object
                tags.add(taggedToo.getEndNodeId());
            }

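Cypher traverses the graph checking for “Relationship Uniqueness”. In other words it does not traverse over the same relationship twice in a single path, which is why the starting Tag never showed up as kw2 in the queries above. Our traversal has no such check, so the starting Tag ends up in the tags list of every test plan. We don’t want to include it in our combinations because it will always be there and we know that already. So we can remove it from our tags list: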

// We need to remove our starting tag
tags.remove(Long.valueOf(tagNode.getId()));
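One Java detail worth a second look here: a List<Long> has both remove(int index) and remove(Object o), and we want the second one, removal by value. Here is a minimal sketch of that step on plain ids. The class and method names are made up for illustration and are not part of the procedure.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper for illustration only, not from the procedure's source.
class RemoveStartingTag {

    // Returns a copy of tagIds without the starting tag's id.
    static List<Long> withoutStartingTag(List<Long> tagIds, long startingTagId) {
        List<Long> others = new ArrayList<>(tagIds);
        // Boxing to Long makes it explicit that we call remove(Object),
        // i.e. remove by value, and not remove(int index).
        others.remove(Long.valueOf(startingTagId));
        return others;
    }

    public static void main(String[] args) {
        List<Long> tags = Arrays.asList(7L, 0L, 2L, 1L);
        System.out.println(withoutStartingTag(tags, 7L)); // prints [0, 2, 1]
    }
}
```

Note that remove(Object) only drops the first match, which is enough as long as a test plan is not tagged by the same tag twice.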
Here is where things get interesting. We are going to generate all the possible combinations for the tags found connected to our test plan. If our tags had ids of [0, 2, 1], we will get a list of [0,1],[0,2],[1,2]. Notice it will be sorted, but I’m assuming the test plans do not get tagged by the same tag twice. We could use a Set if that was the case.

// Get all the sorted combinations of tags
List<long[]> list = new ArrayList<>();
combinations(k.intValue() - 1, tags.stream().sorted().mapToLong(l -> l).toArray(), list);
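The combinations helper itself is not shown in this post. Below is one way it could be written, assuming the signature implied by the call above: the size of each combination, a sorted long[] of ids, and an output list. Treat it as a sketch under that assumption and not as the code from the repository. The class name and the private collect method are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of a combinations helper, not the repository's code.
class CombinationsSketch {

    // Appends every k-element combination of 'input' to 'out'.
    // If 'input' is sorted ascending, every combination is sorted too.
    static void combinations(int k, long[] input, List<long[]> out) {
        if (k < 0 || k > input.length) {
            return; // nothing can be chosen
        }
        collect(input, k, 0, new long[k], 0, out);
    }

    private static void collect(long[] input, int k, int start,
                                long[] current, int depth, List<long[]> out) {
        if (depth == k) {
            out.add(current.clone()); // copy, because 'current' is reused
            return;
        }
        // Stop early enough to leave elements for the remaining slots
        for (int i = start; i <= input.length - (k - depth); i++) {
            current[depth] = input[i];
            collect(input, k, i + 1, current, depth + 1, out);
        }
    }

    public static void main(String[] args) {
        long[] tags = {0, 2, 1};
        Arrays.sort(tags);
        List<long[]> list = new ArrayList<>();
        combinations(2, tags, list);
        for (long[] combo : list) {
            System.out.println(Arrays.toString(combo)); // [0, 1] then [0, 2] then [1, 2]
        }
    }
}
```

Because each recursive call only looks to the right of the element it just picked, every set of tags comes out exactly once, which plays the same role as the ID(kw2)<ID(kw3) condition in the Cypher query.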

Now we have to add our list of combinations to our count. We’ll cheat and use Arrays::toString on the long[] to turn the [0,1] into the String “[0, 1]” (note the space after the comma, which we will have to strip later).

            // For each combination add it to the counts map, or increment it
            for (long[] item : list) {
                counts.merge(Arrays.toString(item), 1, (c, one) -> c + 1);
            }
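Why the String cheat? A long[] uses identity for equals and hashCode, so two arrays with the same ids would land in separate map entries. A small sketch of the counting step on its own, with a hypothetical class name and Integer::sum standing in for the lambda:

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical standalone version of the counting step.
class CountCombinations {

    // Counts how often each combination occurs, keyed by its String form.
    static Map<String, Integer> count(List<long[]> combos) {
        Map<String, Integer> counts = new HashMap<>();
        for (long[] combo : combos) {
            // Arrays.toString gives "[0, 1]", which compares by content
            counts.merge(Arrays.toString(combo), 1, Integer::sum);
        }
        return counts;
    }

    public static void main(String[] args) {
        List<long[]> combos = Arrays.asList(
                new long[]{0, 1}, new long[]{0, 2}, new long[]{0, 1});
        Map<String, Integer> counts = count(combos);
        System.out.println(counts.get("[0, 1]")); // prints 2
        System.out.println(counts.get("[0, 2]")); // prints 1
    }
}
```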

Once we’ve done this for all the test plans, we sort the count:

        // Sort the results in descending order
        List<Map.Entry<String,Integer>> results =  new ArrayList<>(counts.entrySet());
        results.sort( Map.Entry.<String, Integer>comparingByValue().reversed() );

…and reduce the results to just the top x given to us by the limit variable.

// Get the top x results
results = results.subList(0,Math.min(results.size(), limit.intValue()));
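Sorting a HashMap's entries by count alone leaves combinations with equal counts in whatever order the map happens to iterate. If a stable order matters, a tie-breaker on the key is one option. The sketch below is an alternative using streams, not what the procedure does, and the class and method names are made up.

```java
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical alternative to the sort-then-subList step.
class TopCounts {

    // Returns the 'limit' entries with the highest counts,
    // breaking ties by key so the order is deterministic.
    static List<Map.Entry<String, Integer>> top(Map<String, Integer> counts, int limit) {
        Comparator<Map.Entry<String, Integer>> byCountDesc =
                Map.Entry.<String, Integer>comparingByValue().reversed();
        return counts.entrySet().stream()
                .sorted(byCountDesc.thenComparing(Map.Entry.<String, Integer>comparingByKey()))
                .limit(Math.max(0, limit)) // guard against a negative limit
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        Map<String, Integer> counts = new HashMap<>();
        counts.put("[1, 2]", 3);
        counts.put("[0, 1]", 4);
        counts.put("[0, 2]", 3);
        for (Map.Entry<String, Integer> entry : top(counts, 2)) {
            System.out.println(entry.getKey() + " -> " + entry.getValue());
        }
    }
}
```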

The last piece of the puzzle is returning the results. To do this we will turn each into a map, using the key of each result entry to find the tag nodes and populate the names of the tags.

        return results.stream().map(result -> {
            Map<String, Object> triple = new HashMap<>();
            int count = 1;
            // The key looks like "[0, 1]": drop the brackets and spaces, then split into ids
            for (String id : result.getKey().substring(1, result.getKey().length() - 1).replace(" ", "").split(",")) {
                Node found = db.getNodeById(Long.valueOf(id));
                String name = (String)found.getProperty("name");
                triple.put("tag" + count++, name);
            }
            triple.put("count", result.getValue());
            return new MapResult(triple);
        });
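The substring/replace/split chain is just undoing Arrays::toString. Pulled out on its own it could look like the sketch below, where parseIds is a hypothetical helper name and the empty-key check is an extra safety net that the procedure does not need when combinations have at least one tag:

```java
import java.util.Arrays;

// Hypothetical helper that reverses Arrays.toString(long[]).
class KeyParsing {

    // Turns a key such as "[0, 1]" back into {0, 1}.
    static long[] parseIds(String key) {
        // Drop the surrounding brackets and the spaces after each comma
        String inner = key.substring(1, key.length() - 1).replace(" ", "");
        if (inner.isEmpty()) {
            return new long[0]; // "[]" has no ids
        }
        String[] parts = inner.split(",");
        long[] ids = new long[parts.length];
        for (int i = 0; i < parts.length; i++) {
            ids[i] = Long.parseLong(parts[i]);
        }
        return ids;
    }

    public static void main(String[] args) {
        long[] original = {0, 1, 42};
        long[] roundTrip = parseIds(Arrays.toString(original));
        System.out.println(Arrays.equals(original, roundTrip)); // prints true
    }
}
```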

So assuming we created some sample data, we could call the stored procedure like this:

CALL com.maxdemarzi.combinations(tag, number, limit)
CALL com.maxdemarzi.combinations('first tag', 3, 10);    

…and our results would look like:

{
  "tag1": "second tag",
  "count": 4,
  "tag2": "third tag"
}
{
  "tag1": "third tag",
  "count": 3,
  "tag2": "fourth tag"
}
{
  "tag1": "second tag",
  "count": 3,
  "tag2": "fourth tag"
}

To find quads instead of triples, we just change the middle variable:

CALL com.maxdemarzi.combinations('first tag', 4, 10);    

As always, the source code is on GitHub. Is there a better way of doing this? Probably. This was just the first thing that came to mind. If you find a smarter way, send me a pull request!

…and yes I know it’s “triples” not “triplets” but aren’t those babies cute?
