summaryrefslogtreecommitdiff
blob: 9add28626a23a732acd5dc0cd44868c76c5b226a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
From abcebfe9714644d4e259e53b10e0e9417b5b864f Mon Sep 17 00:00:00 2001
From: Jerome Flesch <jflesch@openpaper.work>
Date: Sun, 21 Apr 2024 13:31:03 +0200
Subject: [PATCH] backend/guesswork/labels/sklearn: fix use of
 scipy.sparse.hstack() + numpy.zeros()

Closes #1111
---
 .../paperwork_backend/guesswork/label/sklearn/__init__.py   | 5 +++--
 paperwork-backend/src/paperwork_backend/model/fake.py       | 6 ++++++
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/paperwork-backend/src/paperwork_backend/guesswork/label/sklearn/__init__.py b/paperwork-backend/src/paperwork_backend/guesswork/label/sklearn/__init__.py
index b2af4350..8633211f 100644
--- a/paperwork-backend/src/paperwork_backend/guesswork/label/sklearn/__init__.py
+++ b/paperwork-backend/src/paperwork_backend/guesswork/label/sklearn/__init__.py
@@ -191,7 +191,8 @@ class UpdatableVectorizer(object):
             )
             if required_padding > 0:
                 doc_vector = numpy.hstack([
-                    doc_vector, numpy.zeros((required_padding,))
+                    doc_vector,
+                    numpy.zeros((required_padding,))
                 ])
             if sum_features is None:
                 sum_features = doc_vector
@@ -339,7 +340,7 @@ class Corpus(object):
             if required_padding > 0:
                 doc_vector = scipy.sparse.hstack([
                     scipy.sparse.csr_matrix(doc_vector),
-                    numpy.zeros((required_padding,))
+                    numpy.zeros((1, required_padding))
                 ])
             else:
                 doc_vector = scipy.sparse.csr_matrix(doc_vector)
diff --git a/paperwork-backend/src/paperwork_backend/model/fake.py b/paperwork-backend/src/paperwork_backend/model/fake.py
index 29beae97..f06fe18e 100644
--- a/paperwork-backend/src/paperwork_backend/model/fake.py
+++ b/paperwork-backend/src/paperwork_backend/model/fake.py
@@ -125,6 +125,12 @@ class Plugin(openpaperwork_core.PluginBase):
             if doc['url'] == doc_url:
                 out.update(doc['labels'])
 
+    def doc_has_labels_by_url(self, doc_url):
+        for doc in self.docs:
+            if doc['url'] == doc_url:
+                return True if len(doc["labels"]) > 0 else None
+        return None
+
     def doc_add_label_by_url(self, doc_url, label, color=None):
         if color is None:
             all_labels = set()
-- 
GitLab