update post cropping notebook (#75)
2
.github/workflows/ci.yml
поставляемый
@ -27,8 +27,6 @@ jobs:
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install -e .
|
||||
python -m spacy download en_core_web_md
|
||||
python -m textblob.download_corpora
|
||||
- name: Run pytest linux (linux-only)
|
||||
if: matrix.os == 'ubuntu-22.04'
|
||||
run: |
|
||||
|
||||
@ -1,48 +1,41 @@
|
||||
import os
|
||||
import ntpath
|
||||
from PIL import Image
|
||||
from matplotlib.patches import ConnectionPatch
|
||||
import cv2
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
|
||||
# use this function to visualize the matches
|
||||
def plot_matches(img1, img2, keypoints1, keypoints2):
|
||||
fig, axes = plt.subplots(1, 2, figsize=(16, 7))
|
||||
|
||||
# draw images
|
||||
axes[0].imshow(img1)
|
||||
axes[1].imshow(img2)
|
||||
|
||||
# draw matches
|
||||
for kp1, kp2 in zip(keypoints1, keypoints2):
|
||||
c = np.random.rand(3)
|
||||
con = ConnectionPatch(
|
||||
xyA=kp1,
|
||||
coordsA=axes[0].transData,
|
||||
xyB=kp2,
|
||||
coordsB=axes[1].transData,
|
||||
color=c,
|
||||
)
|
||||
fig.add_artist(con)
|
||||
axes[0].plot(*kp1, color=c, marker="x")
|
||||
axes[1].plot(*kp2, color=c, marker="x")
|
||||
|
||||
plt.show()
|
||||
MIN_MATCH_COUNT = 6
|
||||
FLANN_INDEX_KDTREE = 1
|
||||
|
||||
|
||||
# use this function to visualize the matches from sift
|
||||
def draw_matches(matches, img1, img2, kp1, kp2):
|
||||
MIN_MATCH_COUNT = 4
|
||||
def draw_matches(
|
||||
matches: List,
|
||||
img1: np.ndarray,
|
||||
img2: np.ndarray,
|
||||
kp1: List[cv2.KeyPoint],
|
||||
kp2: List[cv2.KeyPoint],
|
||||
) -> None:
|
||||
"""Visualize the matches from SIFT.
|
||||
|
||||
Args:
|
||||
matches (list[cv2.Match]): List of cv2.Match matches on the image.
|
||||
img1 (np.ndarray): The reference image.
|
||||
img2 (np.ndarray): The social media post.
|
||||
kp1 (list[cv2.KeyPoint]): List of keypoints from the first image.
|
||||
kp2 (list[cv2.KeyPoint]): List of keypoints from the second image.
|
||||
"""
|
||||
if len(matches) > MIN_MATCH_COUNT:
|
||||
# Estimate homography between template and scene
|
||||
src_pts = np.float32([kp1[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
|
||||
dst_pts = np.float32([kp2[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)
|
||||
|
||||
M = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)[0]
|
||||
|
||||
if not isinstance(M, np.ndarray):
|
||||
print("Could not match images for drawing.")
|
||||
return
|
||||
# Draw detected template in scene image
|
||||
h = img1.shape[0]
|
||||
w = img1.shape[1]
|
||||
@ -50,9 +43,7 @@ def draw_matches(matches, img1, img2, kp1, kp2):
|
||||
-1, 1, 2
|
||||
)
|
||||
dst = cv2.perspectiveTransform(pts, M)
|
||||
|
||||
img2 = cv2.polylines(img2, [np.int32(dst)], True, 255, 3, cv2.LINE_AA)
|
||||
|
||||
h1 = img1.shape[0]
|
||||
h2 = img2.shape[0]
|
||||
w1 = img1.shape[1]
|
||||
@ -61,60 +52,92 @@ def draw_matches(matches, img1, img2, kp1, kp2):
|
||||
nheight = max(h1, h2)
|
||||
hdif = int((h2 - h1) / 2)
|
||||
newimg = np.zeros((nheight, nwidth, 3), np.uint8)
|
||||
|
||||
for i in range(3):
|
||||
newimg[hdif : hdif + h1, :w1, i] = img1
|
||||
newimg[:h2, w1 : w1 + w2, i] = img2
|
||||
|
||||
# Draw SIFT keypoint matches
|
||||
for m in matches:
|
||||
pt1 = (int(kp1[m.queryIdx].pt[0]), int(kp1[m.queryIdx].pt[1] + hdif))
|
||||
pt2 = (int(kp2[m.trainIdx].pt[0] + w1), int(kp2[m.trainIdx].pt[1]))
|
||||
cv2.line(newimg, pt1, pt2, (255, 0, 0))
|
||||
|
||||
plt.imshow(newimg)
|
||||
plt.show()
|
||||
else:
|
||||
print("Not enough matches are found - %d/%d" % (len(matches), MIN_MATCH_COUNT))
|
||||
|
||||
|
||||
# compute matches from sift
|
||||
def matching_points(img1, img2):
|
||||
def matching_points(
|
||||
img1: np.ndarray, img2: np.ndarray
|
||||
) -> Tuple[cv2.DMatch, List[cv2.KeyPoint], List[cv2.KeyPoint]]:
|
||||
"""Computes keypoint matches using the SIFT algorithm between two images.
|
||||
|
||||
Args:
|
||||
img1 (np.ndarray): The reference image.
|
||||
img2 (np.ndarray): The social media post.
|
||||
Returns:
|
||||
cv2.DMatch: List of filtered keypoint matches.
|
||||
cv2.KeyPoint: List of keypoints from the first image.
|
||||
cv2.KeyPoint: List of keypoints from the second image.
|
||||
"""
|
||||
img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
|
||||
img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
|
||||
sift = cv2.xfeatures2d.SIFT_create()
|
||||
sift = cv2.SIFT_create()
|
||||
kp1, des1 = sift.detectAndCompute(img1, None)
|
||||
kp2, des2 = sift.detectAndCompute(img2, None)
|
||||
des1 = np.float32(des1)
|
||||
des2 = np.float32(des2)
|
||||
# Initialize and use FLANN
|
||||
FLANN_INDEX_KDTREE = 1
|
||||
index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
|
||||
search_params = dict(checks=50)
|
||||
flann = cv2.FlannBasedMatcher(index_params, search_params)
|
||||
matches = flann.knnMatch(des1, des2, k=2)
|
||||
|
||||
filtered_matches = []
|
||||
for m, n in matches:
|
||||
if m.distance < 0.7 * n.distance:
|
||||
filtered_matches.append(m)
|
||||
|
||||
# draw_matches(filtered_matches, img1, img2, kp1, kp2)
|
||||
|
||||
return filtered_matches, kp1, kp2
|
||||
|
||||
|
||||
# extract match points from matches
|
||||
def kp_from_matches(matches, kp1, kp2):
|
||||
def kp_from_matches(matches, kp1: np.ndarray, kp2: np.ndarray) -> Tuple[Tuple, Tuple]:
|
||||
"""Extract the match indices from the keypoints.
|
||||
|
||||
Args:
|
||||
kp1 (np.ndarray): Key points of the matches,
|
||||
kp2 (np.ndarray): Key points of the matches,
|
||||
Returns:
|
||||
tuple: Index of the descriptor in the list of train descriptors.
|
||||
tuple: index of the descriptor in the list of query descriptors.
|
||||
"""
|
||||
kp1 = np.float32([kp1[m.queryIdx].pt for m in matches])
|
||||
kp2 = np.float32([kp2[m.trainIdx].pt for m in matches])
|
||||
return kp1, kp2
|
||||
|
||||
|
||||
# estimate a crop corner for posts image via matches
|
||||
def compute_crop_corner(
|
||||
matches, kp1, kp2, region=30, h_margin=28, v_margin=5, min_match=6
|
||||
):
|
||||
matches: cv2.DMatch,
|
||||
kp1: np.ndarray,
|
||||
kp2: np.ndarray,
|
||||
region: int = 30,
|
||||
h_margin: int = 0,
|
||||
v_margin: int = 5,
|
||||
min_match: int = 6,
|
||||
) -> Optional[Tuple[int, int]]:
|
||||
"""Estimate the position on the image from where to crop.
|
||||
|
||||
Args:
|
||||
matches (cv2.DMatch): The matched objects on the image.
|
||||
kp1 (np.ndarray): Key points of the matches for the reference image.
|
||||
kp2 (np.ndarray): Key points of the matches for the social media posts.
|
||||
region (int, optional): Area to consider around the keypoints.
|
||||
Defaults to 30.
|
||||
h_margin (int, optional): Horizontal margin to subtract from the minimum
|
||||
horizontal position. Defaults to 0.
|
||||
v_margin (int, optional): Vertical margin to subtract from the minimum
|
||||
vertical position. Defaults to 5.
|
||||
min_match: Minimum number of matches required. Defaults to 6.
|
||||
Returns:
|
||||
tuple, optional: Tuple of vertical and horizontal crop corner coordinates.
|
||||
"""
|
||||
kp1, kp2 = kp_from_matches(matches, kp1, kp2)
|
||||
ys = kp2[:, 1]
|
||||
covers = []
|
||||
@ -123,147 +146,195 @@ def compute_crop_corner(
|
||||
series = pd.Series(ys_c)
|
||||
is_between = series.between(0, region)
|
||||
covers.append(is_between.sum())
|
||||
|
||||
covers = np.array(covers)
|
||||
if covers.max() < min_match:
|
||||
return None
|
||||
|
||||
kp_id = ys[covers.argmax()]
|
||||
v = int(kp_id) - v_margin if int(kp_id) > v_margin else int(kp_id)
|
||||
hs = []
|
||||
for kp in kp2:
|
||||
if 0 <= kp[1] - v <= region:
|
||||
hs.append(kp[0])
|
||||
|
||||
# do not use margin if h < image width/2, else use margin
|
||||
h = int(np.min(hs)) - h_margin if int(np.min(hs)) > h_margin else 0
|
||||
|
||||
return v, h
|
||||
|
||||
|
||||
# crop the posts image
|
||||
def crop_posts_image(
|
||||
ref_view, view, plt_match=False, plt_crop=False, correct_margin=700
|
||||
):
|
||||
"""
|
||||
get file lists from dir and sub dirs
|
||||
ref_view: List,
|
||||
view: np.ndarray,
|
||||
) -> Union[None, Tuple[np.ndarray, int, int, int]]:
|
||||
"""Crop the social media post to exclude additional comments. Sometimes also crops the
|
||||
image part of the post - this is put back in later.
|
||||
|
||||
|
||||
ref_view: ref_view for crop the posts images
|
||||
view: posts image that need cropping
|
||||
rte: None - not cropped, or (crop_view, number of matches)
|
||||
Args:
|
||||
ref_views (list): List of all the reference images (as numpy arrays) that signify
|
||||
below which regions should be cropped.
|
||||
view (np.ndarray): The image to crop.
|
||||
Returns:
|
||||
np.ndarray: The cropped social media post.
|
||||
"""
|
||||
filtered_matches, kp1, kp2 = matching_points(ref_view, view)
|
||||
MIN_MATCH_COUNT = 6
|
||||
if len(filtered_matches) < MIN_MATCH_COUNT:
|
||||
# not enough matches found
|
||||
# print("Found too few matches - {}".format(filtered_matches))
|
||||
return None
|
||||
corner = compute_crop_corner(filtered_matches, kp1, kp2)
|
||||
if corner is None:
|
||||
# no cropping corner found
|
||||
# print("Found no corner")
|
||||
return None
|
||||
v, h = corner
|
||||
# if the match is on the right-hand side of the image,
|
||||
# it is likely that there is an image to the left
|
||||
# that should not be cropped
|
||||
# in this case, we adjust the margin for the text to be
|
||||
# cropped to `correct_margin`
|
||||
# if the match is more to the left on the image, we assume
|
||||
# it starts at horizontal position 0 and do not want to
|
||||
# cut off any characters from the text
|
||||
correct_margin = 30
|
||||
if h >= view.shape[1] / 2:
|
||||
h = h - correct_margin
|
||||
else:
|
||||
h = 0
|
||||
crop_view = view[0:v, h:, :]
|
||||
return crop_view, len(filtered_matches), v, h
|
||||
|
||||
if plt_match:
|
||||
img1 = cv2.cvtColor(ref_view, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
def crop_posts_from_refs(
|
||||
ref_views: List,
|
||||
view: np.ndarray,
|
||||
plt_match: bool = False,
|
||||
plt_crop: bool = False,
|
||||
plt_image: bool = False,
|
||||
) -> np.ndarray:
|
||||
"""Crop the social media post comments from the image.
|
||||
|
||||
Args:
|
||||
ref_views (list): List of all the reference images (as numpy arrays) that signify
|
||||
below which regions should be cropped.
|
||||
view (np.ndarray): The image to crop.
|
||||
Returns:
|
||||
np.ndarray: The cropped social media post.
|
||||
"""
|
||||
crop_view = None
|
||||
# initialize the number of found matches per reference to zero
|
||||
# so later we can select the reference with the most matches
|
||||
max_matchs = 0
|
||||
rte = None
|
||||
found_match = False
|
||||
for ref_view in ref_views:
|
||||
rte = crop_posts_image(ref_view, view)
|
||||
if rte is not None:
|
||||
crop_img, match_num, v, h = rte
|
||||
if match_num > max_matchs:
|
||||
# find the reference with the most matches to crop accordingly
|
||||
crop_view = crop_img
|
||||
final_ref = ref_view
|
||||
final_v = v
|
||||
final_h = h
|
||||
max_matchs = match_num
|
||||
found_match = True
|
||||
|
||||
if found_match and plt_match:
|
||||
# plot the match
|
||||
filtered_matches, kp1, kp2 = matching_points(final_ref, view)
|
||||
img1 = cv2.cvtColor(final_ref, cv2.COLOR_BGR2GRAY)
|
||||
img2 = cv2.cvtColor(view, cv2.COLOR_BGR2GRAY)
|
||||
draw_matches(filtered_matches, img1, img2, kp1, kp2)
|
||||
|
||||
corner = compute_crop_corner(filtered_matches, kp1, kp2)
|
||||
if corner is None:
|
||||
return None
|
||||
v, h = corner
|
||||
if view.shape[1] - h > correct_margin:
|
||||
h = view.shape[1] - ref_view.shape[1]
|
||||
if view.shape[1] - h < ref_view.shape[1]:
|
||||
h = view.shape[1] - ref_view.shape[1]
|
||||
|
||||
crop_view = view[0:v, h:, :]
|
||||
if plt_crop:
|
||||
view[v, :, 0:3] = [255, 0, 0]
|
||||
view[:, h, 0:3] = [255, 0, 0]
|
||||
plt.imshow(view)
|
||||
if found_match and plt_crop:
|
||||
# plot the cropped image
|
||||
view2 = view.copy()
|
||||
view2[final_v, :, 0:3] = [255, 0, 0]
|
||||
view2[:, final_h, 0:3] = [255, 0, 0]
|
||||
plt.imshow(cv2.cvtColor(view2, cv2.COLOR_BGR2RGB))
|
||||
plt.show()
|
||||
plt.imshow(cv2.cvtColor(crop_view, cv2.COLOR_BGR2RGB))
|
||||
plt.show()
|
||||
|
||||
plt.imshow(crop_view)
|
||||
plt.show()
|
||||
|
||||
return crop_view, len(filtered_matches)
|
||||
|
||||
|
||||
def get_file_list(dir, filelist, ext=None, convert_unix=True):
|
||||
"""
|
||||
get file lists from dir and sub dirs
|
||||
dir: root dir for file lists
|
||||
ext: file extension
|
||||
rte: File list
|
||||
"""
|
||||
if os.path.isfile(dir):
|
||||
if ext is None:
|
||||
filelist.append(dir)
|
||||
else:
|
||||
if ext in dir[-3:]:
|
||||
filelist.append(dir)
|
||||
|
||||
elif os.path.isdir(dir):
|
||||
for s in os.listdir(dir):
|
||||
new_dir = os.path.join(dir, s)
|
||||
get_file_list(new_dir, filelist, ext)
|
||||
|
||||
if convert_unix:
|
||||
new_filelist = []
|
||||
for file_ in filelist:
|
||||
file_ = file_.replace("\\", "/")
|
||||
new_filelist.append(file_)
|
||||
return new_filelist
|
||||
else:
|
||||
return filelist
|
||||
|
||||
|
||||
def crop_posts_from_refs(ref_views, view, plt_match=False, plt_crop=False):
|
||||
crop_view = None
|
||||
max_matchs = 0
|
||||
for ref_view in ref_views:
|
||||
rte = crop_posts_image(ref_view, view, plt_match=plt_match, plt_crop=plt_crop)
|
||||
if rte is not None:
|
||||
crop_img, match_num = rte
|
||||
if match_num > max_matchs:
|
||||
crop_view = crop_img
|
||||
max_matchs = match_num
|
||||
# print("match_num = ", match_num)
|
||||
|
||||
if found_match and final_h >= view.shape[1] / 2:
|
||||
# here it would crop the actual image from the social media post
|
||||
# to avoid this, we check the position from where it would crop
|
||||
# if > than half of the width of the image, also keep all that is
|
||||
# on the left-hand side of the crop
|
||||
crop_post = crop_image_from_post(view, final_h)
|
||||
if plt_image:
|
||||
# plot the image part of the social media post
|
||||
plt.imshow(cv2.cvtColor(crop_post, cv2.COLOR_BGR2RGB))
|
||||
plt.show()
|
||||
# now concatenate the image and the text part
|
||||
crop_view = paste_image_and_comment(crop_post, crop_view)
|
||||
return crop_view
|
||||
|
||||
|
||||
def crop_posts_from_files(
|
||||
ref_dir, crop_dir, save_crop_dir, plt_match=False, plt_crop=False
|
||||
):
|
||||
ref_list = []
|
||||
ref_list = get_file_list(ref_dir, ref_list, ext="png")
|
||||
ref_views = []
|
||||
for ref_file in ref_list:
|
||||
ref_view = np.array(Image.open(ref_file))
|
||||
ref_views.append(ref_view)
|
||||
crop_list = []
|
||||
crop_list = get_file_list(crop_dir, crop_list, ext="png")
|
||||
def crop_image_from_post(view: np.ndarray, final_h: int) -> np.ndarray:
|
||||
"""Crop the image part from the social media post.
|
||||
|
||||
for crop_file in crop_list:
|
||||
view = np.array(Image.open(crop_file))
|
||||
Args:
|
||||
view (np.ndarray): The image to be cropped.
|
||||
final_h: The horizontal position up to which should be cropped.
|
||||
Returns:
|
||||
np.ndarray: The cropped image part."""
|
||||
crop_post = view[:, 0:final_h, :]
|
||||
return crop_post
|
||||
|
||||
|
||||
def paste_image_and_comment(crop_post: np.ndarray, crop_view: np.ndarray) -> np.ndarray:
|
||||
"""Paste the image part and the text part together without the unecessary comments.
|
||||
|
||||
Args:
|
||||
crop_post (np.ndarray): The cropped image part of the social media post.
|
||||
crop_view (np.ndarray): The cropped text part of the social media post.
|
||||
Returns:
|
||||
np.ndarray: The image and text part of the social media post in one image."""
|
||||
h1, w1 = crop_post.shape[:2]
|
||||
h2, w2 = crop_view.shape[:2]
|
||||
image_all = np.zeros((max(h1, h2), w1 + w2, 3), np.uint8)
|
||||
image_all[:h1, :w1, :3] = crop_post
|
||||
image_all[:h2, w1 : w1 + w2, :3] = crop_view
|
||||
return image_all
|
||||
|
||||
|
||||
def crop_media_posts(
|
||||
files, ref_files, save_crop_dir, plt_match=False, plt_crop=False, plt_image=False
|
||||
) -> None:
|
||||
"""Crop social media posts so that comments beyond the first comment/post are cut off.
|
||||
|
||||
Args:
|
||||
files (list): List of all the files to be cropped.
|
||||
ref_files (list): List of all the reference images that signify
|
||||
below which regions should be cropped.
|
||||
save_crop_dir (str): Directory where to write the cropped social media posts to.
|
||||
plt_match (Bool, optional): Display the matched areas on the social media post.
|
||||
Defaults to False.
|
||||
plt_crop (Bool, optional): Display the cropped text part of the social media post.
|
||||
Defaults to False.
|
||||
plt_image (Bool, optional): Display the image part of the social media post.
|
||||
Defaults to False.
|
||||
"""
|
||||
|
||||
# get the reference images with regions that signify areas to crop
|
||||
ref_views = []
|
||||
for ref_file in ref_files:
|
||||
ref_view = cv2.imread(ref_file)
|
||||
ref_views.append(ref_view)
|
||||
|
||||
# parse through the social media posts to be cropped
|
||||
for crop_file in files:
|
||||
view = cv2.imread(crop_file)
|
||||
print("Doing file {}".format(crop_file))
|
||||
crop_view = crop_posts_from_refs(
|
||||
ref_views, view, plt_match=plt_match, plt_crop=plt_crop
|
||||
ref_views,
|
||||
view,
|
||||
plt_match=plt_match,
|
||||
plt_crop=plt_crop,
|
||||
plt_image=plt_image,
|
||||
)
|
||||
if crop_view is not None:
|
||||
# save the image to the provided folder
|
||||
filename = ntpath.basename(crop_file)
|
||||
save_path = os.path.join(save_crop_dir, filename)
|
||||
save_path = save_path.replace("\\", "/")
|
||||
cv2.imwrite(save_path, crop_view)
|
||||
|
||||
|
||||
def test_crop_from_file():
|
||||
# Load images
|
||||
view1 = np.array(Image.open("data/ref/ref-06.png"))
|
||||
view2 = np.array(Image.open("data/napsa/102956_eng.png"))
|
||||
crop_view, _ = crop_posts_image(view1, view2, plt_match=True, plt_crop=True)
|
||||
cv2.imwrite("data/crop_100489_ind.png", crop_view)
|
||||
|
||||
|
||||
def test_crop_from_folder():
|
||||
ref_dir = "./data/ref"
|
||||
crop_dir = "./data/apsa"
|
||||
save_crop_dir = "data/crop"
|
||||
crop_posts_from_files(
|
||||
ref_dir, crop_dir, save_crop_dir, plt_match=False, plt_crop=False
|
||||
)
|
||||
|
||||
Двоичные данные
ammico/data/ref/ref-00.png
Обычный файл
|
После Ширина: | Высота: | Размер: 6.1 KiB |
Двоичные данные
ammico/data/ref/ref-01.png
Обычный файл
|
После Ширина: | Высота: | Размер: 2.5 KiB |
Двоичные данные
ammico/data/ref/ref-02.png
Обычный файл
|
После Ширина: | Высота: | Размер: 4.5 KiB |
Двоичные данные
ammico/data/ref/ref-03.png
Обычный файл
|
После Ширина: | Высота: | Размер: 3.3 KiB |
Двоичные данные
ammico/data/ref/ref-04.png
Обычный файл
|
После Ширина: | Высота: | Размер: 4.1 KiB |
Двоичные данные
ammico/data/ref/ref-05.png
Обычный файл
|
После Ширина: | Высота: | Размер: 3.3 KiB |
Двоичные данные
ammico/data/ref/ref-06.png
Обычный файл
|
После Ширина: | Высота: | Размер: 5.2 KiB |
Двоичные данные
ammico/data/ref/ref-07.png
Обычный файл
|
После Ширина: | Высота: | Размер: 1.3 KiB |
Двоичные данные
ammico/data/ref/ref-08.png
Обычный файл
|
После Ширина: | Высота: | Размер: 14 KiB |
Двоичные данные
ammico/data/ref/ref-09.png
Обычный файл
|
После Ширина: | Высота: | Размер: 6.4 KiB |
Двоичные данные
ammico/data/ref/ref-10.png
Обычный файл
|
После Ширина: | Высота: | Размер: 17 KiB |
Двоичные данные
ammico/data/ref/ref-11.png
Обычный файл
|
После Ширина: | Высота: | Размер: 18 KiB |
Двоичные данные
ammico/data/ref/ref-12.png
Обычный файл
|
После Ширина: | Высота: | Размер: 7.9 KiB |
Двоичные данные
ammico/data/ref/ref-13.png
Обычный файл
|
После Ширина: | Высота: | Размер: 7.3 KiB |
Двоичные данные
ammico/data/test-crop-image.png
Обычный файл
|
После Ширина: | Высота: | Размер: 1.4 MiB |
@ -1,62 +1,86 @@
|
||||
import ammico.cropposts as crpo
|
||||
import cv2
|
||||
import pytest
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
TEST_IMAGE_1 = "pic1.png"
|
||||
TEST_IMAGE_2 = "pic2.png"
|
||||
|
||||
|
||||
def test_matching_points(get_path):
|
||||
ref_view = np.array(Image.open(get_path + TEST_IMAGE_2))
|
||||
view = np.array(Image.open(get_path + TEST_IMAGE_1))
|
||||
filtered_matches, _, _ = crpo.matching_points(ref_view, view)
|
||||
@pytest.fixture
|
||||
def open_images(get_path):
|
||||
ref_view = cv2.imread(get_path + TEST_IMAGE_2)
|
||||
view = cv2.imread(get_path + TEST_IMAGE_1)
|
||||
return ref_view, view
|
||||
|
||||
|
||||
def test_matching_points(open_images):
|
||||
filtered_matches, _, _ = crpo.matching_points(open_images[0], open_images[1])
|
||||
assert len(filtered_matches) > 0
|
||||
|
||||
|
||||
def test_kp_from_matches(get_path):
|
||||
ref_view = np.array(Image.open(get_path + TEST_IMAGE_2))
|
||||
view = np.array(Image.open(get_path + TEST_IMAGE_1))
|
||||
filtered_matches, kp1, kp2 = crpo.matching_points(ref_view, view)
|
||||
def test_kp_from_matches(open_images):
|
||||
filtered_matches, kp1, kp2 = crpo.matching_points(open_images[0], open_images[1])
|
||||
kp1, kp2 = crpo.kp_from_matches(filtered_matches, kp1, kp2)
|
||||
|
||||
assert kp1.shape[0] == len(filtered_matches)
|
||||
assert kp2.shape[0] == len(filtered_matches)
|
||||
assert kp1.shape[1] == 2
|
||||
assert kp2.shape[1] == 2
|
||||
|
||||
|
||||
def test_compute_crop_corner(get_path):
|
||||
ref_view = np.array(Image.open(get_path + TEST_IMAGE_2))
|
||||
view = np.array(Image.open(get_path + TEST_IMAGE_1))
|
||||
filtered_matches, kp1, kp2 = crpo.matching_points(ref_view, view)
|
||||
def test_compute_crop_corner(open_images):
|
||||
filtered_matches, kp1, kp2 = crpo.matching_points(open_images[0], open_images[1])
|
||||
corner = crpo.compute_crop_corner(filtered_matches, kp1, kp2)
|
||||
print(view.shape)
|
||||
print(corner)
|
||||
assert corner is not None
|
||||
v, h = corner
|
||||
assert 0 <= v < view.shape[0]
|
||||
assert 0 <= h < view.shape[0]
|
||||
assert 0 <= v < open_images[1].shape[0]
|
||||
assert 0 <= h < open_images[1].shape[0]
|
||||
|
||||
|
||||
def test_crop_posts_image(get_path):
|
||||
ref_view = np.array(Image.open(get_path + TEST_IMAGE_2))
|
||||
view = np.array(Image.open(get_path + TEST_IMAGE_1))
|
||||
rte = crpo.crop_posts_image(ref_view, view)
|
||||
def test_crop_posts_image(open_images):
|
||||
rte = crpo.crop_posts_image(open_images[0], open_images[1])
|
||||
assert rte is not None
|
||||
crop_view, match_num = rte
|
||||
crop_view, match_num, _, _ = rte
|
||||
assert match_num > 0
|
||||
assert crop_view.shape[0] * crop_view.shape[1] <= view.shape[0] * view.shape[1]
|
||||
assert (
|
||||
crop_view.shape[0] * crop_view.shape[1]
|
||||
<= open_images[1].shape[0] * open_images[1].shape[1]
|
||||
)
|
||||
|
||||
|
||||
def test_crop_posts_from_refs(get_path):
|
||||
ref_view = np.array(Image.open(get_path + TEST_IMAGE_2))
|
||||
view = np.array(Image.open(get_path + TEST_IMAGE_1))
|
||||
ref_views = [ref_view]
|
||||
crop_view = crpo.crop_posts_from_refs(ref_views, view)
|
||||
assert crop_view.shape[0] * crop_view.shape[1] <= view.shape[0] * view.shape[1]
|
||||
def test_crop_posts_from_refs(open_images):
|
||||
crop_view = crpo.crop_posts_from_refs([open_images[0]], open_images[1])
|
||||
assert (
|
||||
crop_view.shape[0] * crop_view.shape[1]
|
||||
<= open_images[1].shape[0] * open_images[1].shape[1]
|
||||
)
|
||||
|
||||
|
||||
def test_get_file_list(get_path):
|
||||
ref_list = []
|
||||
ref_list = crpo.get_file_list(get_path, ref_list, ext="png")
|
||||
assert len(ref_list) > 0
|
||||
def test_crop_image_from_post(open_images):
|
||||
crop_post = crpo.crop_image_from_post(open_images[0], 4)
|
||||
ref_array = np.array(
|
||||
[[220, 202, 155], [221, 204, 155], [221, 204, 155], [221, 204, 155]],
|
||||
dtype=np.uint8,
|
||||
)
|
||||
assert np.array_equal(crop_post[0], ref_array)
|
||||
|
||||
|
||||
def test_paste_image_and_comment(open_images):
|
||||
full_post = crpo.paste_image_and_comment(open_images[0], open_images[1])
|
||||
ref_array1 = np.array([220, 202, 155], dtype=np.uint8)
|
||||
ref_array2 = np.array([74, 76, 64], dtype=np.uint8)
|
||||
assert np.array_equal(full_post[0, 0], ref_array1)
|
||||
assert np.array_equal(full_post[-1, -1], ref_array2)
|
||||
|
||||
|
||||
def test_crop_media_posts(get_path, tmp_path):
|
||||
files = [get_path + TEST_IMAGE_1]
|
||||
ref_files = [get_path + TEST_IMAGE_2]
|
||||
crpo.crop_media_posts(files, ref_files, tmp_path)
|
||||
assert len(list(tmp_path.iterdir())) == 1
|
||||
# now check that image in tmp_path is the cropped one
|
||||
filename = tmp_path / "pic1.png"
|
||||
cropped_image = cv2.imread(str(filename))
|
||||
ref = np.array([222, 205, 156], dtype=np.uint8)
|
||||
assert np.array_equal(cropped_image[0, 0], ref)
|
||||
|
||||
1
notebooks/.~lock.data_out.csv#
сгенерированный
@ -1 +0,0 @@
|
||||
,iulusoy,ssc08,03.05.2023 12:16,file:///home/iulusoy/.config/libreoffice/4;
|
||||
140
notebooks/cropposts.ipynb
сгенерированный
@ -1,6 +1,7 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "b25986d7",
|
||||
"metadata": {},
|
||||
@ -9,6 +10,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "c8a5a491",
|
||||
"metadata": {},
|
||||
@ -17,6 +19,35 @@
|
||||
"We can set some manually cropped views from social media posts as reference for cropping the same type social media posts images."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "70ffb7e2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Please ignore this cell: extra install steps that are only executed when running the notebook on Google Colab\n",
|
||||
"# flake8-noqa-cell\n",
|
||||
"import os\n",
|
||||
"if 'google.colab' in str(get_ipython()):\n",
|
||||
" # we're running on colab\n",
|
||||
" # first install pinned version of setuptools (latest version doesn't seem to work with this package on colab)\n",
|
||||
" %pip install setuptools==61 -qqq\n",
|
||||
" # install the moralization package\n",
|
||||
" %pip install git+https://github.com/ssciwr/AMMICO.git -qqq\n",
|
||||
"\n",
|
||||
" # prevent loading of the wrong opencv library\n",
|
||||
" %pip uninstall -y opencv-contrib-python\n",
|
||||
" %pip install opencv-contrib-python\n",
|
||||
"\n",
|
||||
" from google.colab import drive\n",
|
||||
" drive.mount('/content/drive')\n",
|
||||
"\n",
|
||||
" if not os.path.isdir('/content/ref'):\n",
|
||||
" !wget https://github.com/ssciwr/AMMICO/archive/refs/heads/ref-data.zip -q\n",
|
||||
" !unzip -qq ref-data.zip -d . && mv -f AMMICO-ref-data/data/ref . && rm -rf AMMICO-ref-data ref-data.zip"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@ -25,10 +56,20 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import ammico.cropposts as crpo\n",
|
||||
"import numpy as np\n",
|
||||
"import ammico.utils as utils\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from PIL import Image\n",
|
||||
"import cv2"
|
||||
"import cv2\n",
|
||||
"import importlib_resources\n",
|
||||
"pkg = importlib_resources.files(\"ammico\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "e7b8127f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The cropping is carried out by finding reference images on the image to be cropped. If a reference matches a region on the image, then everything below the matched region is removed. Manually look at a reference and an example post with the code below."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -39,16 +80,31 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# load ref view for cropping the same type social media posts images.\n",
|
||||
"ref_view = np.array(Image.open(\"../data/ref/ref-00.png\"))\n",
|
||||
"plt.imshow(ref_view)\n",
|
||||
"# substitute the below paths for your samples\n",
|
||||
"path_ref = pkg / \"data\" / \"ref\" / \"ref-00.png\"\n",
|
||||
"ref_view = cv2.imread(path_ref)\n",
|
||||
"RGB_ref_view = cv2.cvtColor(ref_view, cv2.COLOR_BGR2RGB)\n",
|
||||
"plt.figure(figsize=(10, 15))\n",
|
||||
"plt.imshow(RGB_ref_view)\n",
|
||||
"plt.show()\n",
|
||||
"\n",
|
||||
"view = np.array(Image.open(\"../data/all/102790S_eng.png\"))\n",
|
||||
"path_post = pkg / \"data\" / \"test-crop-image.png\"\n",
|
||||
"view = cv2.imread(path_post)\n",
|
||||
"RGB_view = cv2.cvtColor(view, cv2.COLOR_BGR2RGB)\n",
|
||||
"plt.figure(figsize=(10, 15))\n",
|
||||
"plt.imshow(view)\n",
|
||||
"plt.imshow(RGB_view)\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "49a11f61",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You can now crop the image and check on the way that everything looks fine. `plt_match` will plot the matches on the image and below which line content will be cropped; `plt_crop` will plot the cropped text part of the social media post with the comments removed; `plt_image` will plot the image part of the social media post if applicable."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@ -56,42 +112,25 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# crop a posts from reference view, set plt_match=True, plt_crop=True\n",
|
||||
"crop_view, match_num = crpo.crop_posts_image(\n",
|
||||
" ref_view, view, plt_match=True, plt_crop=True\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# save cropped images\n",
|
||||
"cv2.imwrite(\"test.png\", crop_view)\n",
|
||||
"print(match_num)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "67fc7b82",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"view2 = np.array(Image.open(\"../data/all/102790S_eng.png\"))\n",
|
||||
"plt.figure(figsize=(10, 15))\n",
|
||||
"plt.imshow(view2)\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "21d87359",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# crop a posts from the same reference view, set plt_match=True, plt_crop=True\n",
|
||||
"crop_view, match_num = crpo.crop_posts_image(\n",
|
||||
" ref_view, view2, plt_match=True, plt_crop=True\n",
|
||||
"# crop a posts from reference view, check the cropping \n",
|
||||
"# this will only plot something if the reference is found on the image\n",
|
||||
"crop_view = crpo.crop_posts_from_refs(\n",
|
||||
" [ref_view], view, \n",
|
||||
" plt_match=True, plt_crop=True, plt_image=True,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "1929e549",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Batch crop images from the image folder given in `crop_dir`. The cropped images will save in `save_crop_dir` folder with the same file name as the original file. The reference images with the items to match are provided in `ref_dir`.\n",
|
||||
"\n",
|
||||
"Sometimes the cropping will be imperfect, due to improper matches on the image. It is sometimes easier to first categorize the social media posts and then set different references in the reference folder `ref_dir`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@ -99,23 +138,22 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Batch crop images from image folders.\n",
|
||||
"# The cropped images will save in save_crop_dir folders with the same file name form origin.\n",
|
||||
"# We can set many types reference images in ref_dir folder, to crop posts images in different types social media, like twitter or facebook.\n",
|
||||
"ref_dir = \"../data/ref\"\n",
|
||||
"crop_dir = \"../data/all\"\n",
|
||||
"save_crop_dir = \"../data/crop\"\n",
|
||||
"crpo.crop_posts_from_files(\n",
|
||||
" ref_dir, crop_dir, save_crop_dir, plt_match=False, plt_crop=False\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"Batch cropping images well done\")"
|
||||
"crop_dir = \"../ammico/data/\"\n",
|
||||
"ref_dir = \"../ammico/data/ref\"\n",
|
||||
"save_crop_dir = \"data/crop/\"\n",
|
||||
"\n",
|
||||
"files = utils.find_files(path=crop_dir,limit=10,)\n",
|
||||
"ref_files = utils.find_files(path=ref_dir, limit=100)\n",
|
||||
"\n",
|
||||
"crpo.crop_media_posts(files, ref_files, save_crop_dir, plt_match=True, plt_crop=False, plt_image=False)\n",
|
||||
"print(\"Batch cropping images done\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c10d9f6f",
|
||||
"id": "b3b3c1ad",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
|
||||
@ -24,13 +24,16 @@ classifiers = [
|
||||
dependencies = [
|
||||
"bertopic<=0.14.1",
|
||||
"cvlib",
|
||||
"dash",
|
||||
"dash_renderjson",
|
||||
"deepface<=0.0.75",
|
||||
"googletrans==3.1.0a0",
|
||||
"google-cloud-vision",
|
||||
"grpcio",
|
||||
"importlib_metadata",
|
||||
"ipython",
|
||||
"ipywidgets<8.0.5",
|
||||
"ipykernel",
|
||||
"jupyter_dash",
|
||||
"matplotlib",
|
||||
"numpy<=1.23.4",
|
||||
"pandas",
|
||||
@ -39,22 +42,16 @@ dependencies = [
|
||||
"protobuf",
|
||||
"pytest",
|
||||
"pytest-cov",
|
||||
"pytest-xdist",
|
||||
"requests",
|
||||
"Requests",
|
||||
"retina_face",
|
||||
"salesforce-lavis",
|
||||
"setuptools",
|
||||
"spacy",
|
||||
"spacytextblob",
|
||||
"tensorflow",
|
||||
"textblob",
|
||||
"torch",
|
||||
"transformers",
|
||||
"google-cloud-vision",
|
||||
"setuptools",
|
||||
"opencv-contrib-python",
|
||||
"dash",
|
||||
"jupyter_dash",
|
||||
"dash_renderjson",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@ -62,3 +59,7 @@ ammico_prefetch_models = "ammico.utils:ammico_prefetch_models"
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["ammico"]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
# Include any png files found in the "data" subdirectory of "ammico"
|
||||
"ammico.data" = ["*.png"]
|
||||