Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/core/imagelib/SegIEPolys.py
628 views
1
import numpy as np
2
import cv2
3
from enum import IntEnum
4
5
6
class SegIEPolyType(IntEnum):
    """Kind of a segmentation polygon.

    The integer values are what ``SegIEPoly.dump()`` writes (via ``int(type)``),
    so they must stay stable for previously saved data to load correctly.
    """

    # Polygon area is painted with 0 by SegIEPolys.overlay_mask.
    EXCLUDE = 0
    # Polygon area is painted with 1 by SegIEPolys.overlay_mask.
    INCLUDE = 1
9
10
11
12
class SegIEPoly():
    """A single editable include/exclude polygon with simple undo/redo.

    Points live in ``self.pts`` as a float32 array of (x, y) rows.
    ``self.n`` is the number of *active* points: rows at index ``>= n`` are
    points hidden by ``undo()`` that ``redo()`` can bring back.
    """

    def __init__(self, type=None, pts=None, **kwargs):
        """Create a polygon.

        Args:
            type: polygon kind, normally a ``SegIEPolyType`` or its int value.
                (The name shadows the builtin but is part of the public
                keyword interface, e.g. ``SegIEPoly(**dumped_cfg)``.)
            pts: optional sequence of (x, y) points, converted to float32.
            **kwargs: accepted and ignored (presumably so that dicts with
                extra keys can still be unpacked here; verify with callers).
        """
        self.type = type

        if pts is None:
            pts = np.empty((0, 2), dtype=np.float32)
        else:
            pts = self._as_points(pts)
        self.pts = pts
        # n: active point count. n_max is kept in sync by every mutator but is
        # not read inside this class (NOTE(review): presumably used by callers).
        self.n_max = self.n = len(pts)

    @staticmethod
    def _as_points(pts):
        """Copy ``pts`` into a float32 array; an empty input becomes shape (0, 2).

        The (0, 2) shape matters because ``add_pt``/``insert_pt`` concatenate
        along axis 0, which fails for a 1-D empty array of shape (0,).
        """
        arr = np.array(pts, dtype=np.float32)
        if arr.size == 0:
            arr = arr.reshape(0, 2)
        return arr

    def dump(self):
        """Return a dict with ``'type'`` (as int) and the active ``'pts'`` (array copy)."""
        return {'type': int(self.type),
                'pts' : self.get_pts(),
                }

    def identical(self, b):
        """Return True if ``b`` has the same type and the same active points."""
        if self.n != b.n or self.type != b.type:
            return False
        return bool(np.array_equal(self.pts[0:self.n], b.pts[0:b.n]))

    def get_type(self):
        return self.type

    def add_pt(self, x, y):
        """Append point (x, y) after the active points, discarding any undone points."""
        self.pts = np.append(self.pts[0:self.n], [ ( float(x), float(y) ) ], axis=0).astype(np.float32)
        self.n_max = self.n = self.n + 1

    def undo(self):
        """Hide the last active point (never below 0); return the new active count."""
        self.n = max(0, self.n-1)
        return self.n

    def redo(self):
        """Restore one previously undone point, if any; return the new active count."""
        self.n = min(len(self.pts), self.n+1)
        return self.n

    def redo_clip(self):
        """Drop the undone points (rows beyond ``n``) so they can no longer be redone."""
        self.pts = self.pts[0:self.n]
        self.n_max = self.n

    def insert_pt(self, n, pt):
        """Insert point ``pt`` (any (x, y) sequence) at index ``n``, 0 <= n <= active count.

        Rows after ``n`` (including undone points) are kept after the new point.

        Raises:
            ValueError: if ``n`` is outside ``[0, self.n]``.
        """
        if n < 0 or n > self.n:
            raise ValueError("insert_pt out of range")
        # Accept tuples/lists as well as ndarrays; always store one float32 row.
        new_row = np.asarray(pt, dtype=np.float32).reshape(1, 2)
        self.pts = np.concatenate( (self.pts[0:n], new_row, self.pts[n:]), axis=0)
        self.n_max = self.n = self.n+1

    def remove_pt(self, n):
        """Remove the active point at index ``n``.

        Raises:
            ValueError: if ``n`` is outside ``[0, self.n)``.
        """
        if n < 0 or n >= self.n:
            raise ValueError("remove_pt out of range")
        self.pts = np.concatenate( (self.pts[0:n], self.pts[n+1:]), axis=0)
        self.n_max = self.n = self.n-1

    def get_last_point(self):
        """Return a copy of the last active point.

        NOTE(review): when ``n == 0`` this indexes ``pts[-1]``, which is an undone
        point if any exist (or raises IndexError if ``pts`` is empty).
        """
        return self.pts[self.n-1].copy()

    def get_pts(self):
        """Return a copy of the active points."""
        return self.pts[0:self.n].copy()

    def get_pts_count(self):
        return self.n

    def set_point(self, id, pt):
        """Overwrite the point at index ``id`` in place."""
        self.pts[id] = pt

    def set_points(self, pts):
        """Replace all points with ``pts``, stored as float32 like the constructor does.

        Without the float32 conversion, integer input would yield an int array
        that ``mult_points`` cannot scale in place by a float.
        """
        self.pts = self._as_points(pts)
        self.n_max = self.n = len(self.pts)

    def mult_points(self, val):
        """Scale every stored point (including undone ones) in place by ``val``."""
        self.pts *= val
83
84
85
class SegIEPolys():
    """Ordered collection of ``SegIEPoly`` objects that together describe one mask."""

    def __init__(self):
        self.polys = []

    def identical(self, b):
        """Return True if ``b`` has as many polygons and each pair is ``identical``."""
        if len(self.polys) != len(b.polys):
            return False
        return all(a_poly.identical(b_poly) for a_poly, b_poly in zip(self.polys, b.polys))

    def add_poly(self, ie_poly_type):
        """Append a new empty polygon of type ``ie_poly_type`` and return it."""
        poly = SegIEPoly(ie_poly_type)
        self.polys.append(poly)
        return poly

    def remove_poly(self, poly):
        """Remove ``poly`` if present; absent polygons are silently ignored."""
        if poly in self.polys:
            self.polys.remove(poly)

    def has_polys(self):
        return len(self.polys) != 0

    def get_poly(self, id):
        return self.polys[id]

    def get_polys(self):
        return self.polys

    def get_pts_count(self):
        """Return the total number of active points over all polygons."""
        return sum(poly.get_pts_count() for poly in self.polys)

    def sort(self):
        """Reorder so all INCLUDE polygons precede all EXCLUDE polygons.

        Relative order within each type is preserved. Because ``overlay_mask``
        draws in list order, this makes EXCLUDE polygons cut holes out of the
        INCLUDE area. Raises KeyError for a polygon whose type is neither.
        """
        poly_by_type = { SegIEPolyType.EXCLUDE : [], SegIEPolyType.INCLUDE : [] }

        for poly in self.polys:
            poly_by_type[poly.type].append(poly)

        self.polys = poly_by_type[SegIEPolyType.INCLUDE] + poly_by_type[SegIEPolyType.EXCLUDE]

    def __iter__(self):
        yield from self.polys

    def overlay_mask(self, mask):
        """Rasterize the polygons into ``mask`` in place, in list order.

        INCLUDE polygons are filled with 1 and all other polygons with 0 in every
        channel. Point coordinates are truncated to int32 before drawing.

        Args:
            mask: array of shape (H, W, C) or (H, W); modified in place.
        """
        # Support single-channel (H, W) masks as well as (H, W, C).
        channels = mask.shape[2] if mask.ndim == 3 else 1
        white = (1,) * channels
        black = (0,) * channels
        for poly in self.polys:
            pts = poly.get_pts().astype(np.int32)
            if len(pts) != 0:
                cv2.fillPoly(mask, [pts], white if poly.type == SegIEPolyType.INCLUDE else black)

    def dump(self):
        """Return ``{'polys': [...]}`` with each polygon's ``dump()`` output."""
        return {'polys' : [ poly.dump() for poly in self.polys ] }

    def mult_points(self, val):
        """Scale the points of every polygon in place by ``val``."""
        for poly in self.polys:
            poly.mult_points(val)

    @staticmethod
    def load(data=None):
        """Build a ``SegIEPolys`` from ``dump()`` output or the legacy list format.

        Args:
            data: ``None`` (empty result), a dict with a ``'polys'`` list of
                ``SegIEPoly`` keyword dicts, or a legacy list of ``(type, pts)``
                pairs. Any other type is ignored and yields an empty collection.

        Returns:
            A new, sorted ``SegIEPolys``.
        """
        ie_polys = SegIEPolys()
        if data is not None:
            if isinstance(data, list):
                # Backward compatibility: legacy data is a list of (type, pts) pairs.
                ie_polys.polys = [ SegIEPoly(type=poly_type, pts=pts) for (poly_type, pts) in data ]
            elif isinstance(data, dict):
                ie_polys.polys = [ SegIEPoly(**poly_cfg) for poly_cfg in data['polys'] ]

        ie_polys.sort()

        return ie_polys
159