From cae47e875c8e086fb319478d84eab304f2f9e69b Mon Sep 17 00:00:00 2001
From: unknown <david.hoksza@gmail.com>
Date: Wed, 23 Oct 2019 13:28:27 +0200
Subject: [PATCH] implemented creation of connected components graph

---
 talent/bioentity.py      |   6 +++
 talent/graph_utils.py    |   2 +-
 talent/layout.py         | 111 +++++++++++++++++++++++++++++----------
 talent/utils/transfer.py |  29 ++++++++--
 4 files changed, 116 insertions(+), 32 deletions(-)

diff --git a/talent/bioentity.py b/talent/bioentity.py
index 2ecd745..aaa318d 100644
--- a/talent/bioentity.py
+++ b/talent/bioentity.py
@@ -1070,6 +1070,7 @@ class BioEntity:
                  layouts:Dict[LAYOUT_TYPE,BioEntityLayout]=None):
         self._g = g
         self._id = id
+        self.__id_history: List[str] = []
         self._name = name
         self._type:int = type #preferably an SBO id 290 (this is then used e.g. in is_simple_molecule)
         self._element_id = element_id
@@ -1087,7 +1088,12 @@ class BioEntity:
     def get_id(self):
         return self._id
 
+    def get_original_id(self):
+        return self.__id_history[-1] if len(self.__id_history) > 0 else self.get_id()
+
     def set_id(self, id):
+        if id is not None:
+            self.__id_history.append(self._id)
         self._id = id
 
     def get_graph(self) -> nx.MultiGraph:
diff --git a/talent/graph_utils.py b/talent/graph_utils.py
index 4a06879..6814346 100644
--- a/talent/graph_utils.py
+++ b/talent/graph_utils.py
@@ -702,7 +702,7 @@ def get_species(g, node_id) -> 'be.Species':
     return get_node_data(get_node(g, node_id))
 
 
-def get_node_data(node, g:nx.MultiGraph = None) -> 'be.Species' or 'be.Reaction':
+def get_node_data(node: Dict or str, g:nx.MultiGraph = None) -> 'be.Species' or 'be.Reaction':
     if g:
         return get_node(g, node)['data']
     else:
diff --git a/talent/layout.py b/talent/layout.py
index 2370439..e83340a 100644
--- a/talent/layout.py
+++ b/talent/layout.py
@@ -3,7 +3,7 @@ import copy
 import sys
 import math
 import logging
-from typing import List, Any, Dict, Set, Tuple, Collection
+from typing import List, Any, Dict, Set, Tuple, Collection, NamedTuple
 from scipy.optimize import linear_sum_assignment
 import numpy as np
 
@@ -24,24 +24,31 @@ class NodesMapping:
     is added to the result graph, a mapping is added.
     """
     def __init__(self):
-        self.sourceTarget: Dict[str, List[str]] = {}
-        self.targetSource: Dict[str, List[str]] = {}
+        self.__sourceTarget: Dict[str, List[str]] = {}
+        self.__targetSource: Dict[str, List[str]] = {}
 
     def addMapping(self, source, target):
 
-        if source not in self.sourceTarget:
-            self.sourceTarget[source] = []
-        self.sourceTarget[source].append(target)
+        if source not in self.__sourceTarget:
+            self.__sourceTarget[source] = []
+        self.__sourceTarget[source].append(target)
 
-        if target not in self.targetSource:
-            self.targetSource[target] = []
-        self.targetSource[target].append(source)
+        if target not in self.__targetSource:
+            self.__targetSource[target] = []
+        self.__targetSource[target].append(source)
 
     def getTargetMapping(self, target)->List[Any]:
-        return self.targetSource[target]
+        return self.__targetSource[target]
 
     def getSourceMapping(self, source):
-        return self.sourceTarget[source]
+        return self.__sourceTarget[source]
+
+    def getAllTargetIds(self) -> List[str]:
+        return list(self.__targetSource.keys())
+
+    def getAllSourceIds(self) -> List[str]:
+        return list(self.__sourceTarget.keys())
+
 
 class Stats:
     med_one_degree_species_distance = 0
@@ -65,6 +72,13 @@ class MinMaxCoords:
             if m[i] > self.max[i]:
                 self.max[i] = m[i]
 
+# class CCGraphNode(NamedTuple):
+#     x: float
+#     y: float
+#     width: float
+#     height: float
+#     g: nx.MultiGraph
+
 
 def get_mapping_dicts(m):
     dicts = [{}, {}] # The first dictionary contains mapping from target to template, while the second mapping from template to target
@@ -1620,7 +1634,7 @@ def get_graph_stats(g, layout_key=None) -> Stats:
         for s in ods:
             coords_r = get_node_pos(nodes[list(nx.all_neighbors(g, s))[0]], layout_type=be.LAYOUT_TYPE.ORIGINAL)
             coords_s = get_node_pos(nodes[s], layout_type=be.LAYOUT_TYPE.ORIGINAL)
-            dists.append(l2_dist(coords_r, coords_s))
+            dists.append(gr.utils.dist(coords_r, coords_s))
 
         stats.med_one_degree_species_distance = np.median(dists)
     else:
@@ -1634,7 +1648,7 @@ def get_graph_stats(g, layout_key=None) -> Stats:
                 if len(set(nx.all_neighbors(g, reactions[i1])).intersection(nx.all_neighbors(g, reactions[i2]))) > 0:
                     c1 = get_node_pos(nodes[reactions[i1]], layout_type=be.LAYOUT_TYPE.ORIGINAL)
                     c2 = get_node_pos(nodes[reactions[i2]], layout_type=be.LAYOUT_TYPE.ORIGINAL)
-                    dists.append(l2_dist(c1, c2))
+                    dists.append(gr.utils.dist(c1, c2))
         if len(dists) > 0:
             stats.med_reactions_distance = np.median(dists)
 
@@ -2098,6 +2112,7 @@ def duplicate(g, s, r, dist):
     g.remove_edge(r, s)
     g.remove_node(s)
 
+
 def validate_layout(res: nx.MultiGraph, tgt: nx.MultiGraph):
 
     non_laid_out_nodes = []
@@ -2139,21 +2154,63 @@ def validate_layout(res: nx.MultiGraph, tgt: nx.MultiGraph):
             raise Exception('Mismatchs in reactions')
 
 
-# def normalize_vec(p):
-#     """
-#
-#     :param p: Expect to be numpy array
-#     :return:
-#     """
-#
-#     if (sum(p != [0, 0])) == 0: return p
-#
-#     # return p / math.sqrt(p[0] * p[0] + p[1] * p[1])
-#     return p / np.linalg.norm(p);
 
 
-def l2_dist(coords1, coords2):
-    return math.sqrt((coords1[0] - coords2[0]) ** 2 + (coords1[1] - coords2[1]) ** 2)
-    # return np.linalg.norm([coords1, coords2])
+
+def mc_create_node_mapping(gs: List[nx.MultiGraph]) -> NodesMapping:
+    """Collect, over all graphs in gs, the mapping from a node's original id to its current id."""
+    mapping = NodesMapping()
+    for graph in gs:
+        for node_id in graph:
+            node_data = gu.get_node_data(node_id, graph)
+            # source = id reported by get_original_id(), target = id of the node in its graph
+            mapping.addMapping(node_data.get_original_id(), node_id)
+    return mapping
+
+
+def mc_locate_orig_id(id: str, gs_nodes: List[Set[str]], nm:NodesMapping) -> int:
+    """Return the index of the first set in gs_nodes which contains a node that id maps to in nm."""
+    new_ids = set(nm.getSourceMapping(id))
+    for i, nodes in enumerate(gs_nodes):
+        if not new_ids.isdisjoint(nodes):
+            return i
+    raise AssertionError('No graph contains a node mapped from id {}'.format(id))  # survives -O
+
+
+def mc_create_cc_graph(gs: List[nx.MultiGraph], g_tgt_orig: nx.MultiGraph) -> nx.MultiGraph:
+    """Build the graph of connected components: node i stands for gs[i] and two nodes are joined
+    by an edge for every reaction of g_tgt_orig whose neighbors map into both of the graphs.
+
+    Raises KeyError when a neighbor of a reaction in g_tgt_orig is not an original id of any node."""
+    from itertools import combinations  # local import, so the block does not rely on file imports
+
+    nm = mc_create_node_mapping(gs)
+    gs_nodes = [set(g.nodes) for g in gs]
+    g_cc = nx.MultiGraph()
+    # range must be called, not subscripted: add one node per graph, even if it stays isolated
+    g_cc.add_nodes_from(range(len(gs)))
+    for r_id in gu.get_all_reaction_ids(g_tgt_orig):
+        g_ixs = {mc_locate_orig_id(id=s_id, gs_nodes=gs_nodes, nm=nm) for s_id in g_tgt_orig[r_id]}
+        if len(g_ixs) > 1:
+            # a reaction spanning multiple compartments: connect every pair of the graphs involved
+            g_cc.add_edges_from(combinations(g_ixs, 2))
+
+    return g_cc
+
+
+def merge_compartments(cmp_gs: Dict[str, List[nx.MultiGraph]], g_tgt: nx.MultiGraph):
+    """Merge the graphs of all compartments; unfinished - only the processing order is computed.
+
+    cmp_gs holds the graphs per compartment, g_tgt is the graph the reactions are read from."""
+    logging.info("Merging compartments")
+    gs_cc: List[nx.MultiGraph] = [g for gs in cmp_gs.values() for g in gs]
+    # Take ccs in the order from the most highly connected components (if equal, then by size) and lay
+    # given CC based on the reactions connecting CC with its neighbors and reposition when overlapping
+    g_cc = mc_create_cc_graph(gs=gs_cc, g_tgt_orig=g_tgt)
+    centrality = nx.betweenness_centrality(g_cc)
+    order = sorted(centrality.items(), key=lambda kv: (kv[1], len(gs_cc[kv[0]])), reverse=True)
+    for ix, _ in order:
+        # TODO: get optimal position of gs_cc[ix]; nothing is laid out or returned yet
+        pass
 
 
diff --git a/talent/utils/transfer.py b/talent/utils/transfer.py
index 23df100..9fa4fd1 100644
--- a/talent/utils/transfer.py
+++ b/talent/utils/transfer.py
@@ -7,8 +7,7 @@ import pickle
 
 from sged import sged_graph as ed
 
-from typing import List
-from typing import NamedTuple
+from typing import List, Dict, NamedTuple
 
 import sys
 sys.path.append(os.path.dirname(os.path.abspath(__file__))+"/../..")
@@ -316,11 +315,25 @@ def map_transfer_store(tgt:ReactionGraph, tmps: List[ReactionGraph], cache_path:
     return rv_fnames
 
 
+def group_gs_by_compartment(gs: List[nx.MultiGraph]) -> Dict[str, List[nx.MultiGraph]]:
+    """Group graphs by compartment name.
+
+    Every graph must be non-empty; the compartment name of the first id returned by
+    gu.get_all_species_ids(g) is used as the compartment of the whole graph."""
+    rv: Dict[str, List[nx.MultiGraph]] = {}  # a dict, since it is keyed by compartment name
+    for g in gs:
+        assert len(g) > 0
+        s_id = gu.get_all_species_ids(g)[0]
+        cmp_name = gu.get_node_data(s_id, g).get_compartment_name()
+        rv.setdefault(cmp_name, []).append(g)
+    return rv
+
+
 def transfer(tgt_name, tgt_fname, tgt_fmt, split_by_cmprtmnt, tmps_path, tmp_fmt,
              output_path,
              ddup_tgt=False,
              ddup_tmp=False, cache_path=None, tgt_sbml_source=None, tmp_cc_before_ddup=True,
-             separate_cmprtmnt=False):
+             separate_cmprtmnt=True):
     """
 
     :param tgt_name:
@@ -348,7 +361,7 @@ def transfer(tgt_name, tgt_fname, tgt_fmt, split_by_cmprtmnt, tmps_path, tmp_fmt
 
     if separate_cmprtmnt:
         tgt_ccs = process_tgt(tgt_name, tgt_fname, tgt_fmt, ddup_tgt, True)
-        tmps: List[ReactionGraph] = process_db(tmps_path, tmp_fmt, False, ddup_tmp, tmp_cc_before_ddup)
+        tmps: List[ReactionGraph] = process_db(tmps_path, tmp_fmt, True, ddup_tmp, tmp_cc_before_ddup)
 
         i = 0
         mtrss: List[List[MapTransferResult]] = []
@@ -358,8 +371,16 @@ def transfer(tgt_name, tgt_fname, tgt_fmt, split_by_cmprtmnt, tmps_path, tmp_fmt
             i += 1
 
         # Extract best mapping from each transfer result list
+        g_ress = [mtrs[0].bc.g for mtrs in mtrss]
         # Group results by compartment
+        cmp_gs: Dict[str, List[nx.MultiGraph]] = group_gs_by_compartment(gs=g_ress)
         # Create CC
+        g_tgt_complete: nx.MultiGraph = gu.load_graph(tgt_fname, tgt_fmt)
+        g_res = lt.merge_compartments(cmp_gs=cmp_gs, g_tgt=g_tgt_complete)
+        bc = BeautificationChain(g_res)
+        fname = "{}/{}.pkl".format(output_path, "test", i)
+        bc.save(fname)
+
 
     else:
         tgt_ccs = process_tgt(tgt_name, tgt_fname, tgt_fmt, ddup_tgt, split_by_cmprtmnt)
-- 
GitLab